Skip to content
20 changes: 14 additions & 6 deletions coderd/aitasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,13 @@ func TestTasks(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
ctx = testutil.Context(t, testutil.WaitLong)
user = coderdtest.CreateFirstUser(t, client)
template = createAITemplate(t, client, user)
wantPrompt = "review my code"
exp = codersdk.NewExperimentalClient(client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
ctx = testutil.Context(t, testutil.WaitLong)
user = coderdtest.CreateFirstUser(t, client)
anotherUser, _ = coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
template = createAITemplate(t, client, user)
wantPrompt = "review my code"
exp = codersdk.NewExperimentalClient(client)
)

task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
Expand Down Expand Up @@ -211,6 +212,13 @@ func TestTasks(t *testing.T) {
assert.Equal(t, taskAppID, updated.WorkspaceAppID.UUID, "workspace app id should match")
assert.NotEmpty(t, updated.WorkspaceStatus, "task status should not be empty")

// Another member user should not be able to fetch the task
_, err = codersdk.NewExperimentalClient(anotherUser).TaskByID(ctx, task.ID)
require.Error(t, err, "fetching task should fail for another member user")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())

// Stop the workspace
coderdtest.MustTransitionWorkspace(t, client, task.WorkspaceID.UUID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)

Expand Down
4 changes: 4 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -2989,6 +2989,10 @@ func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task,
return fetch(q.log, q.auth, q.db.GetTaskByID)(ctx, id)
}

func (q *querier) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) {
return fetch(q.log, q.auth, q.db.GetTaskByOwnerIDAndName)(ctx, arg)
}

func (q *querier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) {
return fetch(q.log, q.auth, q.db.GetTaskByWorkspaceID)(ctx, workspaceID)
}
Expand Down
11 changes: 11 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2375,6 +2375,17 @@ func (s *MethodTestSuite) TestTasks() {
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
check.Args(task.ID).Asserts(task, policy.ActionRead).Returns(task)
}))
s.Run("GetTaskByOwnerIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
task := testutil.Fake(s.T(), faker, database.Task{})
dbm.EXPECT().GetTaskByOwnerIDAndName(gomock.Any(), database.GetTaskByOwnerIDAndNameParams{
OwnerID: task.OwnerID,
Name: task.Name,
}).Return(task, nil).AnyTimes()
check.Args(database.GetTaskByOwnerIDAndNameParams{
OwnerID: task.OwnerID,
Name: task.Name,
}).Asserts(task, policy.ActionRead).Returns(task)
}))
s.Run("DeleteTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
task := testutil.Fake(s.T(), faker, database.Task{})
arg := database.DeleteTaskParams{
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions coderd/database/queries/tasks.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ SELECT * FROM tasks_with_status WHERE id = @id::uuid;
-- name: GetTaskByWorkspaceID :one
SELECT * FROM tasks_with_status WHERE workspace_id = @workspace_id::uuid;

-- name: GetTaskByOwnerIDAndName :one
SELECT * FROM tasks_with_status
WHERE
owner_id = @owner_id::uuid
AND deleted_at IS NULL
AND LOWER(name) = LOWER(@name::text);

-- name: ListTasks :many
SELECT * FROM tasks_with_status tws
WHERE tws.deleted_at IS NULL
Expand Down
74 changes: 63 additions & 11 deletions coderd/httpmw/taskparam.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ package httpmw

import (
"context"
"database/sql"
"errors"
"net/http"

"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/xerrors"

"cdr.dev/slog"

"github.com/coder/coder/v2/coderd/database"
Expand All @@ -23,35 +29,81 @@ func TaskParam(r *http.Request) database.Task {
return task
}

// ExtractTaskParam grabs a task from the "task" URL parameter by UUID.
// ExtractTaskParam grabs a task from the "task" URL parameter.
// It supports two lookup strategies:
// 1. Task UUID (primary)
// 2. Task name scoped to owner (secondary)
//
// This middleware depends on ExtractOrganizationMembersParam being in the chain
// to provide the owner context for name-based lookups.
func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
taskID, parsed := ParseUUIDParam(rw, r, "task")
if !parsed {

// Get the task parameter value. We can't use ParseUUIDParam here because
// we need to support non-UUID values (task names) and
// attempt all lookup strategies.
taskParam := chi.URLParam(r, "task")
if taskParam == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "\"task\" must be provided.",
})
return
}
task, err := db.GetTaskByID(ctx, taskID)

// Get owner from OrganizationMembersParam middleware for name-based lookups
members := OrganizationMembersParam(r)
ownerID := members.UserID()

task, err := fetchTaskWithFallback(ctx, db, taskParam, ownerID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
if !httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task.",
Detail: err.Error(),
})
httpapi.ResourceNotFound(rw)
return
}

ctx = context.WithValue(ctx, taskParamContextKey{}, task)

if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil {
rlogger.WithFields(slog.F("task_id", task.ID), slog.F("task_name", task.Name))
rlogger.WithFields(
slog.F("task_id", task.ID),
slog.F("task_name", task.Name),
)
}

next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

func fetchTaskWithFallback(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) {
// Attempt to first lookup the task by UUID.
taskID, err := uuid.Parse(taskParam)
if err == nil {
task, err := db.GetTaskByID(ctx, taskID)
if err == nil {
return task, nil
}
// There may be a task named with a valid UUID. Fall back to name lookup in this case.
if !errors.Is(err, sql.ErrNoRows) {
return database.Task{}, xerrors.Errorf("fetch task by uuid: %w", err)
}
}

// taskParam not a valid UUID, OR valid UUID but not found, so attempt lookup by name.
task, err := db.GetTaskByOwnerIDAndName(ctx, database.GetTaskByOwnerIDAndNameParams{
OwnerID: ownerID,
Name: taskParam,
})
if err != nil {
return database.Task{}, xerrors.Errorf("fetch task by name: %w", err)
}
return task, nil
}
Loading
Loading