Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 88 additions & 4 deletions coderd/aitasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"slices"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -500,7 +501,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param task path string true "Task ID, or task name"
// @Success 200 {object} codersdk.Task
// @Router /api/experimental/tasks/{user}/{task} [get]
//
Expand Down Expand Up @@ -578,7 +579,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param task path string true "Task ID, or task name"
// @Success 202 "Task deletion initiated"
// @Router /api/experimental/tasks/{user}/{task} [delete]
//
Expand Down Expand Up @@ -646,13 +647,96 @@ func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusAccepted)
}

// @Summary Update AI task input
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
// @ID update-task-input
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID, or task name"
// @Param request body codersdk.UpdateTaskInputRequest true "Update task input request"
// @Success 204
// @Router /api/experimental/tasks/{user}/{task}/input [patch]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// taskUpdateInput allows modifying a task's prompt before the agent executes it.
func (api *API) taskUpdateInput(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
task = httpmw.TaskParam(r)
auditor = api.Auditor.Load()
taskResourceInfo = audit.AdditionalFields{}
)

aReq, commitAudit := audit.InitRequest[database.TaskTable](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
AdditionalFields: taskResourceInfo,
})
defer commitAudit()
aReq.Old = task.TaskTable()
aReq.UpdateOrganizationID(task.OrganizationID)

var req codersdk.UpdateTaskInputRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}

if strings.TrimSpace(req.Input) == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Task input is required.",
})
return
}

var updatedTask database.TaskTable
if err := api.Database.InTx(func(tx database.Store) error {
task, err := tx.GetTaskByID(ctx, task.ID)
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to fetch task.",
Detail: err.Error(),
})
}

if task.Status != database.TaskStatusPaused {
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: "Unable to update task input, task must be paused.",
Detail: "Please stop the task's workspace before updating the input.",
})
}

updatedTask, err = tx.UpdateTaskPrompt(ctx, database.UpdateTaskPromptParams{
ID: task.ID,
Prompt: req.Input,
})
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update task input.",
Detail: err.Error(),
})
}

return nil
}, nil); err != nil {
httperror.WriteResponseError(ctx, rw, err)
return
}

aReq.New = updatedTask

httpapi.Write(ctx, rw, http.StatusNoContent, nil)
}

// @Summary Send input to AI task
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
// @ID send-task-input
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param task path string true "Task ID, or task name"
// @Param request body codersdk.TaskSendRequest true "Task input request"
// @Success 204 "Input sent successfully"
// @Router /api/experimental/tasks/{user}/{task}/send [post]
Expand Down Expand Up @@ -726,7 +810,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param task path string true "Task ID, or task name"
// @Success 200 {object} codersdk.TaskLogsResponse
// @Router /api/experimental/tasks/{user}/{task}/logs [get]
//
Expand Down
205 changes: 205 additions & 0 deletions coderd/aitasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/util/slice"
Expand Down Expand Up @@ -738,6 +739,210 @@ func TestTasks(t *testing.T) {
require.Equal(t, http.StatusBadGateway, sdkErr.StatusCode())
})
})

t.Run("UpdateInput", func(t *testing.T) {
tests := []struct {
name string
disableProvisioner bool
transition database.WorkspaceTransition
cancelTransition bool
deleteTask bool
taskInput string
wantStatus codersdk.TaskStatus
wantErr string
wantErrStatusCode int
}{
{
name: "TaskStatusInitializing",
// We want to disable the provisioner so that the task
// never gets provisioned (ensuring it stays in Initializing).
disableProvisioner: true,
taskInput: "Valid prompt",
wantStatus: codersdk.TaskStatusInitializing,
wantErr: "Unable to update",
wantErrStatusCode: http.StatusConflict,
},
{
name: "TaskStatusPaused",
transition: database.WorkspaceTransitionStop,
taskInput: "Valid prompt",
wantStatus: codersdk.TaskStatusPaused,
},
{
name: "TaskStatusError",
transition: database.WorkspaceTransitionStart,
cancelTransition: true,
taskInput: "Valid prompt",
wantStatus: codersdk.TaskStatusError,
wantErr: "Unable to update",
wantErrStatusCode: http.StatusConflict,
},
{
name: "EmptyPrompt",
transition: database.WorkspaceTransitionStop,
// We want to ensure an empty prompt is rejected.
taskInput: "",
wantStatus: codersdk.TaskStatusPaused,
wantErr: "Task input is required.",
wantErrStatusCode: http.StatusBadRequest,
},
{
name: "TaskDeleted",
transition: database.WorkspaceTransitionStop,
deleteTask: true,
taskInput: "Valid prompt",
wantErr: httpapi.ResourceNotFoundResponse.Message,
wantErrStatusCode: http.StatusNotFound,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

client, provisioner := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)

template := createAITemplate(t, client, user)

if tt.disableProvisioner {
provisioner.Close()
}

// Given: We create a task
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "initial prompt",
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")

if !tt.disableProvisioner {
// Given: The Task is running
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)

// Given: We transition the task's workspace
build := coderdtest.CreateWorkspaceBuild(t, client, workspace, tt.transition)
if tt.cancelTransition {
// Given: We cancel the workspace build
err := client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{})
require.NoError(t, err)

coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)

// Then: We expect it to be canceled
build, err = client.WorkspaceBuild(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceStatusCanceled, build.Status)
} else {
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
}
}

if tt.deleteTask {
err = exp.DeleteTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
} else {
// Given: Task has expected status
task, err = exp.TaskByID(ctx, task.ID)
require.NoError(t, err)
require.Equal(t, tt.wantStatus, task.Status)
}

// When: We attempt to update the task input
err = exp.UpdateTaskInput(ctx, task.OwnerName, task.ID, codersdk.UpdateTaskInputRequest{
Input: tt.taskInput,
})
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)

if tt.wantErrStatusCode != 0 {
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, tt.wantErrStatusCode, apiErr.StatusCode())
}

if !tt.deleteTask {
// Then: We expect the input to **not** be updated
task, err = exp.TaskByID(ctx, task.ID)
require.NoError(t, err)
require.NotEqual(t, tt.taskInput, task.InitialPrompt)
}
} else {
require.NoError(t, err)

if !tt.deleteTask {
// Then: We expect the input to be updated
task, err = exp.TaskByID(ctx, task.ID)
require.NoError(t, err)
require.Equal(t, tt.taskInput, task.InitialPrompt)
}
}
})
}

t.Run("NonExistentTask", func(t *testing.T) {
t.Parallel()

client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitShort)

exp := codersdk.NewExperimentalClient(client)

// Attempt to update prompt for non-existent task
err := exp.UpdateTaskInput(ctx, user.UserID.String(), uuid.New(), codersdk.UpdateTaskInputRequest{
Input: "Should fail",
})
require.Error(t, err)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})

t.Run("UnauthorizedUser", func(t *testing.T) {
t.Parallel()

client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
anotherUser, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
ctx := testutil.Context(t, testutil.WaitLong)

template := createAITemplate(t, client, user)

// Create a task as the first user
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "initial prompt",
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid)

// Wait for workspace to complete
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)

// Stop the workspace
build := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)

// Attempt to update prompt as another user should fail with 404 Not Found
otherExp := codersdk.NewExperimentalClient(anotherUser)
err = otherExp.UpdateTaskInput(ctx, task.OwnerName, task.ID, codersdk.UpdateTaskInputRequest{
Input: "Should fail - unauthorized",
})
require.Error(t, err)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
})
}

func TestTasksCreate(t *testing.T) {
Expand Down
Loading
Loading