Skip to content

Commit 82f525b

Browse files
feat(coderd): add task prompt modification endpoint (#20811)
This PR adds the backend implementation for modifying task prompts. Part of coder/internal#1084 ## Changes - New `UpdateTaskPrompt` database query to update task prompts - New PATCH `/api/v2/tasks/{task}/prompt` endpoint ## Notes This is part 1 of a 2-part PR stack. The frontend UI will be added in a follow-up PR based on this branch (#20812). --- 🤖 PR was written by Claude Sonnet 4.5 Thinking using [Coder Mux](https://github.com/coder/cmux) and reviewed by a human 👩
1 parent afd4043 commit 82f525b

File tree

18 files changed

+659
-37
lines changed

18 files changed

+659
-37
lines changed

coderd/aitasks.go

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/url"
99
"slices"
10+
"strings"
1011
"time"
1112

1213
"github.com/google/uuid"
@@ -500,7 +501,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
500501
// @Security CoderSessionToken
501502
// @Tags Experimental
502503
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
503-
// @Param task path string true "Task ID" format(uuid)
504+
// @Param task path string true "Task ID, or task name"
504505
// @Success 200 {object} codersdk.Task
505506
// @Router /api/experimental/tasks/{user}/{task} [get]
506507
//
@@ -578,7 +579,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
578579
// @Security CoderSessionToken
579580
// @Tags Experimental
580581
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
581-
// @Param task path string true "Task ID" format(uuid)
582+
// @Param task path string true "Task ID, or task name"
582583
// @Success 202 "Task deletion initiated"
583584
// @Router /api/experimental/tasks/{user}/{task} [delete]
584585
//
@@ -646,13 +647,96 @@ func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
646647
rw.WriteHeader(http.StatusAccepted)
647648
}
648649

650+
// @Summary Update AI task input
651+
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
652+
// @ID update-task-input
653+
// @Security CoderSessionToken
654+
// @Tags Experimental
655+
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
656+
// @Param task path string true "Task ID, or task name"
657+
// @Param request body codersdk.UpdateTaskInputRequest true "Update task input request"
658+
// @Success 204
659+
// @Router /api/experimental/tasks/{user}/{task}/input [patch]
660+
//
661+
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
662+
// taskUpdateInput allows modifying a task's prompt before the agent executes it.
663+
func (api *API) taskUpdateInput(rw http.ResponseWriter, r *http.Request) {
664+
var (
665+
ctx = r.Context()
666+
task = httpmw.TaskParam(r)
667+
auditor = api.Auditor.Load()
668+
taskResourceInfo = audit.AdditionalFields{}
669+
)
670+
671+
aReq, commitAudit := audit.InitRequest[database.TaskTable](rw, &audit.RequestParams{
672+
Audit: *auditor,
673+
Log: api.Logger,
674+
Request: r,
675+
Action: database.AuditActionWrite,
676+
AdditionalFields: taskResourceInfo,
677+
})
678+
defer commitAudit()
679+
aReq.Old = task.TaskTable()
680+
aReq.UpdateOrganizationID(task.OrganizationID)
681+
682+
var req codersdk.UpdateTaskInputRequest
683+
if !httpapi.Read(ctx, rw, r, &req) {
684+
return
685+
}
686+
687+
if strings.TrimSpace(req.Input) == "" {
688+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
689+
Message: "Task input is required.",
690+
})
691+
return
692+
}
693+
694+
var updatedTask database.TaskTable
695+
if err := api.Database.InTx(func(tx database.Store) error {
696+
task, err := tx.GetTaskByID(ctx, task.ID)
697+
if err != nil {
698+
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
699+
Message: "Failed to fetch task.",
700+
Detail: err.Error(),
701+
})
702+
}
703+
704+
if task.Status != database.TaskStatusPaused {
705+
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
706+
Message: "Unable to update task input, task must be paused.",
707+
Detail: "Please stop the task's workspace before updating the input.",
708+
})
709+
}
710+
711+
updatedTask, err = tx.UpdateTaskPrompt(ctx, database.UpdateTaskPromptParams{
712+
ID: task.ID,
713+
Prompt: req.Input,
714+
})
715+
if err != nil {
716+
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
717+
Message: "Failed to update task input.",
718+
Detail: err.Error(),
719+
})
720+
}
721+
722+
return nil
723+
}, nil); err != nil {
724+
httperror.WriteResponseError(ctx, rw, err)
725+
return
726+
}
727+
728+
aReq.New = updatedTask
729+
730+
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
731+
}
732+
649733
// @Summary Send input to AI task
650734
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
651735
// @ID send-task-input
652736
// @Security CoderSessionToken
653737
// @Tags Experimental
654738
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
655-
// @Param task path string true "Task ID" format(uuid)
739+
// @Param task path string true "Task ID, or task name"
656740
// @Param request body codersdk.TaskSendRequest true "Task input request"
657741
// @Success 204 "Input sent successfully"
658742
// @Router /api/experimental/tasks/{user}/{task}/send [post]
@@ -726,7 +810,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
726810
// @Security CoderSessionToken
727811
// @Tags Experimental
728812
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
729-
// @Param task path string true "Task ID" format(uuid)
813+
// @Param task path string true "Task ID, or task name"
730814
// @Success 200 {object} codersdk.TaskLogsResponse
731815
// @Router /api/experimental/tasks/{user}/{task}/logs [get]
732816
//

coderd/aitasks_test.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/coder/coder/v2/coderd/database/dbfake"
2424
"github.com/coder/coder/v2/coderd/database/dbgen"
2525
"github.com/coder/coder/v2/coderd/database/dbtime"
26+
"github.com/coder/coder/v2/coderd/httpapi"
2627
"github.com/coder/coder/v2/coderd/notifications"
2728
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
2829
"github.com/coder/coder/v2/coderd/util/slice"
@@ -738,6 +739,210 @@ func TestTasks(t *testing.T) {
738739
require.Equal(t, http.StatusBadGateway, sdkErr.StatusCode())
739740
})
740741
})
742+
743+
t.Run("UpdateInput", func(t *testing.T) {
744+
tests := []struct {
745+
name string
746+
disableProvisioner bool
747+
transition database.WorkspaceTransition
748+
cancelTransition bool
749+
deleteTask bool
750+
taskInput string
751+
wantStatus codersdk.TaskStatus
752+
wantErr string
753+
wantErrStatusCode int
754+
}{
755+
{
756+
name: "TaskStatusInitializing",
757+
// We want to disable the provisioner so that the task
758+
// never gets provisioned (ensuring it stays in Initializing).
759+
disableProvisioner: true,
760+
taskInput: "Valid prompt",
761+
wantStatus: codersdk.TaskStatusInitializing,
762+
wantErr: "Unable to update",
763+
wantErrStatusCode: http.StatusConflict,
764+
},
765+
{
766+
name: "TaskStatusPaused",
767+
transition: database.WorkspaceTransitionStop,
768+
taskInput: "Valid prompt",
769+
wantStatus: codersdk.TaskStatusPaused,
770+
},
771+
{
772+
name: "TaskStatusError",
773+
transition: database.WorkspaceTransitionStart,
774+
cancelTransition: true,
775+
taskInput: "Valid prompt",
776+
wantStatus: codersdk.TaskStatusError,
777+
wantErr: "Unable to update",
778+
wantErrStatusCode: http.StatusConflict,
779+
},
780+
{
781+
name: "EmptyPrompt",
782+
transition: database.WorkspaceTransitionStop,
783+
// We want to ensure an empty prompt is rejected.
784+
taskInput: "",
785+
wantStatus: codersdk.TaskStatusPaused,
786+
wantErr: "Task input is required.",
787+
wantErrStatusCode: http.StatusBadRequest,
788+
},
789+
{
790+
name: "TaskDeleted",
791+
transition: database.WorkspaceTransitionStop,
792+
deleteTask: true,
793+
taskInput: "Valid prompt",
794+
wantErr: httpapi.ResourceNotFoundResponse.Message,
795+
wantErrStatusCode: http.StatusNotFound,
796+
},
797+
}
798+
799+
for _, tt := range tests {
800+
t.Run(tt.name, func(t *testing.T) {
801+
t.Parallel()
802+
803+
client, provisioner := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
804+
user := coderdtest.CreateFirstUser(t, client)
805+
ctx := testutil.Context(t, testutil.WaitLong)
806+
807+
template := createAITemplate(t, client, user)
808+
809+
if tt.disableProvisioner {
810+
provisioner.Close()
811+
}
812+
813+
// Given: We create a task
814+
exp := codersdk.NewExperimentalClient(client)
815+
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
816+
TemplateVersionID: template.ActiveVersionID,
817+
Input: "initial prompt",
818+
})
819+
require.NoError(t, err)
820+
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
821+
822+
if !tt.disableProvisioner {
823+
// Given: The Task is running
824+
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
825+
require.NoError(t, err)
826+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
827+
828+
// Given: We transition the task's workspace
829+
build := coderdtest.CreateWorkspaceBuild(t, client, workspace, tt.transition)
830+
if tt.cancelTransition {
831+
// Given: We cancel the workspace build
832+
err := client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{})
833+
require.NoError(t, err)
834+
835+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
836+
837+
// Then: We expect it to be canceled
838+
build, err = client.WorkspaceBuild(ctx, build.ID)
839+
require.NoError(t, err)
840+
require.Equal(t, codersdk.WorkspaceStatusCanceled, build.Status)
841+
} else {
842+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
843+
}
844+
}
845+
846+
if tt.deleteTask {
847+
err = exp.DeleteTask(ctx, codersdk.Me, task.ID)
848+
require.NoError(t, err)
849+
} else {
850+
// Given: Task has expected status
851+
task, err = exp.TaskByID(ctx, task.ID)
852+
require.NoError(t, err)
853+
require.Equal(t, tt.wantStatus, task.Status)
854+
}
855+
856+
// When: We attempt to update the task input
857+
err = exp.UpdateTaskInput(ctx, task.OwnerName, task.ID, codersdk.UpdateTaskInputRequest{
858+
Input: tt.taskInput,
859+
})
860+
if tt.wantErr != "" {
861+
require.ErrorContains(t, err, tt.wantErr)
862+
863+
if tt.wantErrStatusCode != 0 {
864+
var apiErr *codersdk.Error
865+
require.ErrorAs(t, err, &apiErr)
866+
require.Equal(t, tt.wantErrStatusCode, apiErr.StatusCode())
867+
}
868+
869+
if !tt.deleteTask {
870+
// Then: We expect the input to **not** be updated
871+
task, err = exp.TaskByID(ctx, task.ID)
872+
require.NoError(t, err)
873+
require.NotEqual(t, tt.taskInput, task.InitialPrompt)
874+
}
875+
} else {
876+
require.NoError(t, err)
877+
878+
if !tt.deleteTask {
879+
// Then: We expect the input to be updated
880+
task, err = exp.TaskByID(ctx, task.ID)
881+
require.NoError(t, err)
882+
require.Equal(t, tt.taskInput, task.InitialPrompt)
883+
}
884+
}
885+
})
886+
}
887+
888+
t.Run("NonExistentTask", func(t *testing.T) {
889+
t.Parallel()
890+
891+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
892+
user := coderdtest.CreateFirstUser(t, client)
893+
ctx := testutil.Context(t, testutil.WaitShort)
894+
895+
exp := codersdk.NewExperimentalClient(client)
896+
897+
// Attempt to update prompt for non-existent task
898+
err := exp.UpdateTaskInput(ctx, user.UserID.String(), uuid.New(), codersdk.UpdateTaskInputRequest{
899+
Input: "Should fail",
900+
})
901+
require.Error(t, err)
902+
var apiErr *codersdk.Error
903+
require.ErrorAs(t, err, &apiErr)
904+
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
905+
})
906+
907+
t.Run("UnauthorizedUser", func(t *testing.T) {
908+
t.Parallel()
909+
910+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
911+
user := coderdtest.CreateFirstUser(t, client)
912+
anotherUser, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
913+
ctx := testutil.Context(t, testutil.WaitLong)
914+
915+
template := createAITemplate(t, client, user)
916+
917+
// Create a task as the first user
918+
exp := codersdk.NewExperimentalClient(client)
919+
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
920+
TemplateVersionID: template.ActiveVersionID,
921+
Input: "initial prompt",
922+
})
923+
require.NoError(t, err)
924+
require.True(t, task.WorkspaceID.Valid)
925+
926+
// Wait for workspace to complete
927+
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
928+
require.NoError(t, err)
929+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
930+
931+
// Stop the workspace
932+
build := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
933+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
934+
935+
// Attempt to update prompt as another user should fail with 404 Not Found
936+
otherExp := codersdk.NewExperimentalClient(anotherUser)
937+
err = otherExp.UpdateTaskInput(ctx, task.OwnerName, task.ID, codersdk.UpdateTaskInputRequest{
938+
Input: "Should fail - unauthorized",
939+
})
940+
require.Error(t, err)
941+
var apiErr *codersdk.Error
942+
require.ErrorAs(t, err, &apiErr)
943+
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
944+
})
945+
})
741946
}
742947

743948
func TestTasksCreate(t *testing.T) {

0 commit comments

Comments
 (0)