diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index 83af8a2f83838..2f2c657f65036 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -135,7 +135,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) { t.Skip("X11 forwarding is only supported on Linux") } - ctx := testutil.Context(t, testutil.WaitLong) + ctx := testutil.Context(t, testutil.WaitSuperLong) logger := testutil.Logger(t) fs := afero.NewMemMapFs() @@ -238,7 +238,9 @@ func TestServer_X11_EvictionLRU(t *testing.T) { payload := "hello world" go func() { conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+agentssh.X11DefaultDisplayOffset))) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } _, err = conn.Write([]byte(payload)) assert.NoError(t, err) _ = conn.Close() diff --git a/cli/server.go b/cli/server.go index 26d0c8f110403..f9e744761b22e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -55,6 +55,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/pretty" "github.com/coder/quartz" "github.com/coder/retry" @@ -1459,14 +1460,14 @@ func newProvisionerDaemon( tracer := coderAPI.TracerProvider.Tracer(tracing.TracerName) terraformClient, terraformServer := drpcsdk.MemTransportPipe() wg.Add(1) - go func() { + pproflabel.Go(ctx, pproflabel.Service(pproflabel.ServiceTerraformProvisioner), func(ctx context.Context) { defer wg.Done() <-ctx.Done() _ = terraformClient.Close() _ = terraformServer.Close() - }() + }) wg.Add(1) - go func() { + pproflabel.Go(ctx, pproflabel.Service(pproflabel.ServiceTerraformProvisioner), func(ctx context.Context) { defer wg.Done() defer cancel() @@ -1485,7 +1486,7 @@ func newProvisionerDaemon( default: } } - }() + }) connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient) default: diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index 234a72de04c50..16072e6517125 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -20,6 +20,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" @@ -107,10 +108,10 @@ func (e *Executor) WithStatsChannel(ch chan<- Stats) *Executor { // tick from its channel. It will stop when its context is Done, or when // its channel is closed. func (e *Executor) Run() { - go func() { + pproflabel.Go(e.ctx, pproflabel.Service(pproflabel.ServiceLifecycles), func(ctx context.Context) { for { select { - case <-e.ctx.Done(): + case <-ctx.Done(): return case t, ok := <-e.tick: if !ok { @@ -120,15 +121,15 @@ func (e *Executor) Run() { e.metrics.autobuildExecutionDuration.Observe(stats.Elapsed.Seconds()) if e.statsCh != nil { select { - case <-e.ctx.Done(): + case <-ctx.Done(): return case e.statsCh <- stats: } } - e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions)) + e.log.Debug(ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions)) } } - }() + }) } func (e *Executor) runOnce(t time.Time) Stats { diff --git a/coderd/coderd.go b/coderd/coderd.go index 26bf4a7bf9b63..928fa21a95242 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -852,6 +852,7 @@ func New(options *Options) *API { r.Use( httpmw.Recover(api.Logger), + httpmw.WithProfilingLabels, tracing.StatusWriterMiddleware, tracing.Middleware(api.TracerProvider), httpmw.AttachRequestID, @@ -1415,7 +1416,8 @@ func New(options *Options) *API { r.Get("/timings", api.workspaceTimings) r.Route("/acl", func(r chi.Router) { r.Use( - httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing)) + httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing), + ) r.Patch("/", api.patchWorkspaceACL) }) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 99dd9833fa5d6..4e752399e08eb 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5376,6 +5376,26 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa return q.db.UpsertWorkspaceAppAuditSession(ctx, arg) } +func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) { + // This check is probably overly restrictive, but the "correct" check isn't + // necessarily obvious. It's only used as a verification check for ACLs right + // now, which are performed as system. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + return database.ValidateGroupIDsRow{}, err + } + return q.db.ValidateGroupIDs(ctx, groupIDs) +} + +func (q *querier) ValidateUserIDs(ctx context.Context, userIDs []uuid.UUID) (database.ValidateUserIDsRow, error) { + // This check is probably overly restrictive, but the "correct" check isn't + // necessarily obvious. It's only used as a verification check for ACLs right + // now, which are performed as system. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { + return database.ValidateUserIDsRow{}, err + } + return q.db.ValidateUserIDs(ctx, userIDs) +} + func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. return q.GetTemplatesWithFilter(ctx, arg) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 66a477ebfbaba..deca01456244f 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -623,6 +623,11 @@ func (s *MethodTestSuite) TestGroup() { ID: g.ID, }).Asserts(g, policy.ActionUpdate) })) + s.Run("ValidateGroupIDs", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + g := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + check.Args([]uuid.UUID{g.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) } func (s *MethodTestSuite) TestProvisionerJob() { @@ -2077,6 +2082,10 @@ func (s *MethodTestSuite) TestUser() { Interval: int32((time.Hour * 24).Seconds()), }).Asserts(rbac.ResourceUser, policy.ActionRead) })) + s.Run("ValidateUserIDs", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args([]uuid.UUID{u.ID}).Asserts(rbac.ResourceSystem, policy.ActionRead) + })) } func (s *MethodTestSuite) TestWorkspace() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index d0cd0d1ab797d..bbed6b55346c8 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -3372,6 +3372,20 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a return r0, r1 } +func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) { + start := time.Now() + r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds) + m.queryLatencies.WithLabelValues("ValidateGroupIDs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m queryMetricsStore) ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (database.ValidateUserIDsRow, error) { + start := time.Now() + r0, r1 := m.s.ValidateUserIDs(ctx, userIds) + m.queryLatencies.WithLabelValues("ValidateUserIDs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { start := time.Now() templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index e88763ba1eb74..e1d40f12eb521 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7159,6 +7159,36 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg) } +// ValidateGroupIDs mocks base method. +func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateGroupIDs", ctx, groupIds) + ret0, _ := ret[0].(database.ValidateGroupIDsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateGroupIDs indicates an expected call of ValidateGroupIDs. +func (mr *MockStoreMockRecorder) ValidateGroupIDs(ctx, groupIds any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateGroupIDs", reflect.TypeOf((*MockStore)(nil).ValidateGroupIDs), ctx, groupIds) +} + +// ValidateUserIDs mocks base method. +func (m *MockStore) ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (database.ValidateUserIDsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateUserIDs", ctx, userIds) + ret0, _ := ret[0].(database.ValidateUserIDsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateUserIDs indicates an expected call of ValidateUserIDs. +func (mr *MockStoreMockRecorder) ValidateUserIDs(ctx, userIds any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserIDs", reflect.TypeOf((*MockStore)(nil).ValidateUserIDs), ctx, userIds) +} + // Wrappers mocks base method. func (m *MockStore) Wrappers() []string { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a50bb0bb2192a..1ea4ae5376f80 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -689,6 +689,8 @@ type sqlcQuerier interface { // was started. This means that a new row was inserted (no previous session) or // the updated_at is older than stale interval. UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error) + ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (ValidateGroupIDsRow, error) + ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (ValidateUserIDsRow, error) } var _ sqlcQuerier = (*sqlQuerier)(nil) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b078e2dbb29c0..4adc936683067 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2869,6 +2869,37 @@ func (q *sqlQuerier) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDPar return i, err } +const validateGroupIDs = `-- name: ValidateGroupIDs :one +WITH input AS ( + SELECT + unnest($1::uuid[]) AS id +) +SELECT + array_agg(input.id)::uuid[] as invalid_group_ids, + COUNT(*) = 0 as ok +FROM + -- Preserve rows where there is not a matching left (groups) row for each + -- right (input) row... + groups + RIGHT JOIN input ON groups.id = input.id +WHERE + -- ...so that we can retain exactly those rows where an input ID does not + -- match an existing group. + groups.id IS NULL +` + +type ValidateGroupIDsRow struct { + InvalidGroupIds []uuid.UUID `db:"invalid_group_ids" json:"invalid_group_ids"` + Ok bool `db:"ok" json:"ok"` +} + +func (q *sqlQuerier) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (ValidateGroupIDsRow, error) { + row := q.db.QueryRowContext(ctx, validateGroupIDs, pq.Array(groupIds)) + var i ValidateGroupIDsRow + err := row.Scan(pq.Array(&i.InvalidGroupIds), &i.Ok) + return i, err +} + const getTemplateAppInsights = `-- name: GetTemplateAppInsights :many WITH -- Create a list of all unique apps by template, this is used to @@ -14792,6 +14823,39 @@ func (q *sqlQuerier) UpdateUserThemePreference(ctx context.Context, arg UpdateUs return i, err } +const validateUserIDs = `-- name: ValidateUserIDs :one +WITH input AS ( + SELECT + unnest($1::uuid[]) AS id +) +SELECT + array_agg(input.id)::uuid[] as invalid_user_ids, + COUNT(*) = 0 as ok +FROM + -- Preserve rows where there is not a matching left (users) row for each + -- right (input) row... + users + RIGHT JOIN input ON users.id = input.id +WHERE + -- ...so that we can retain exactly those rows where an input ID does not + -- match an existing user... + users.id IS NULL OR + -- ...or that only matches a user that was deleted. + users.deleted = true +` + +type ValidateUserIDsRow struct { + InvalidUserIds []uuid.UUID `db:"invalid_user_ids" json:"invalid_user_ids"` + Ok bool `db:"ok" json:"ok"` +} + +func (q *sqlQuerier) ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (ValidateUserIDsRow, error) { + row := q.db.QueryRowContext(ctx, validateUserIDs, pq.Array(userIds)) + var i ValidateUserIDsRow + err := row.Scan(pq.Array(&i.InvalidUserIds), &i.Ok) + return i, err +} + const getWorkspaceAgentDevcontainersByAgentID = `-- name: GetWorkspaceAgentDevcontainersByAgentID :many SELECT id, workspace_agent_id, created_at, workspace_folder, config_path, name diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 48a5ba5c79968..3413e5832e27d 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -8,6 +8,24 @@ WHERE LIMIT 1; +-- name: ValidateGroupIDs :one +WITH input AS ( + SELECT + unnest(@group_ids::uuid[]) AS id +) +SELECT + array_agg(input.id)::uuid[] as invalid_group_ids, + COUNT(*) = 0 as ok +FROM + -- Preserve rows where there is not a matching left (groups) row for each + -- right (input) row... + groups + RIGHT JOIN input ON groups.id = input.id +WHERE + -- ...so that we can retain exactly those rows where an input ID does not + -- match an existing group. + groups.id IS NULL; + -- name: GetGroupByOrgAndName :one SELECT * diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index eece2f96512ea..0b6e52d6bc918 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -25,6 +25,26 @@ WHERE LIMIT 1; +-- name: ValidateUserIDs :one +WITH input AS ( + SELECT + unnest(@user_ids::uuid[]) AS id +) +SELECT + array_agg(input.id)::uuid[] as invalid_user_ids, + COUNT(*) = 0 as ok +FROM + -- Preserve rows where there is not a matching left (users) row for each + -- right (input) row... + users + RIGHT JOIN input ON users.id = input.id +WHERE + -- ...so that we can retain exactly those rows where an input ID does not + -- match an existing user... + users.id IS NULL OR + -- ...or that only matches a user that was deleted. + users.deleted = true; + -- name: GetUsersByIDs :many -- This shouldn't check for deleted, because it's frequently used -- to look up references to actions. eg. a user could build a workspace diff --git a/coderd/httpmw/pprof.go b/coderd/httpmw/pprof.go new file mode 100644 index 0000000000000..eee3e9c9fdbe1 --- /dev/null +++ b/coderd/httpmw/pprof.go @@ -0,0 +1,30 @@ +package httpmw + +import ( + "context" + "net/http" + "runtime/pprof" + + "github.com/coder/coder/v2/coderd/pproflabel" +) + +// WithProfilingLabels adds a pprof label to all http request handlers. This is +// primarily used to determine if load is coming from background jobs, or from +// http traffic. +func WithProfilingLabels(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Label to differentiate between http and websocket requests. Websocket requests + // are assumed to be long-lived and more resource consuming. + requestType := "http" + if r.Header.Get("Upgrade") == "websocket" { + requestType = "websocket" + } + + pprof.Do(ctx, pproflabel.Service(pproflabel.ServiceHTTPServer, "request_type", requestType), func(ctx context.Context) { + r = r.WithContext(ctx) + next.ServeHTTP(rw, r) + }) + }) +} diff --git a/coderd/pproflabel/pproflabel.go b/coderd/pproflabel/pproflabel.go new file mode 100644 index 0000000000000..cd803b0f1baea --- /dev/null +++ b/coderd/pproflabel/pproflabel.go @@ -0,0 +1,25 @@ +package pproflabel + +import ( + "context" + "runtime/pprof" +) + +// Go is just a convince wrapper to set off a labeled goroutine. +func Go(ctx context.Context, labels pprof.LabelSet, f func(context.Context)) { + go pprof.Do(ctx, labels, f) +} + +const ( + ServiceTag = "service" + + ServiceHTTPServer = "http-api" + ServiceLifecycles = "lifecycle-executor" + ServiceMetricCollector = "metrics-collector" + ServicePrebuildReconciler = "prebuilds-reconciler" + ServiceTerraformProvisioner = "terraform-provisioner" +) + +func Service(name string, pairs ...string) pprof.LabelSet { + return pprof.Labels(append([]string{ServiceTag, name}, pairs...)...) +} diff --git a/coderd/prometheusmetrics/insights/metricscollector.go b/coderd/prometheusmetrics/insights/metricscollector.go index 41d3a0220f391..a095968526ca8 100644 --- a/coderd/prometheusmetrics/insights/metricscollector.go +++ b/coderd/prometheusmetrics/insights/metricscollector.go @@ -14,6 +14,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" ) @@ -158,7 +159,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) { }) } - go func() { + pproflabel.Go(ctx, pproflabel.Service(pproflabel.ServiceMetricCollector), func(ctx context.Context) { defer close(done) defer ticker.Stop() for { @@ -170,7 +171,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) { doTick() } } - }() + }) return func() { closeFunc() <-done diff --git a/coderd/rbac/acl/updatevalidator.go b/coderd/rbac/acl/updatevalidator.go new file mode 100644 index 0000000000000..9785609f2e33a --- /dev/null +++ b/coderd/rbac/acl/updatevalidator.go @@ -0,0 +1,130 @@ +package acl + +import ( + "context" + "fmt" + + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/codersdk" +) + +type UpdateValidator[Role codersdk.WorkspaceRole | codersdk.TemplateRole] interface { + // Users should return a map from user UUIDs (as strings) to the role they + // are being assigned. Additionally, it should return a string that will be + // used as the field name for the ValidationErrors returned from Validate. + Users() (map[string]Role, string) + // Groups should return a map from group UUIDs (as strings) to the role they + // are being assigned. Additionally, it should return a string that will be + // used as the field name for the ValidationErrors returned from Validate. + Groups() (map[string]Role, string) + // ValidateRole should return an error that will be used in the + // ValidationError if the role is invalid for the corresponding resource type. + ValidateRole(role Role) error +} + +func Validate[Role codersdk.WorkspaceRole | codersdk.TemplateRole]( + ctx context.Context, + db database.Store, + v UpdateValidator[Role], +) []codersdk.ValidationError { + // nolint:gocritic // Validate requires full read access to users and groups + ctx = dbauthz.AsSystemRestricted(ctx) + var validErrs []codersdk.ValidationError + + groupRoles, groupsField := v.Groups() + groupIDs := make([]uuid.UUID, 0, len(groupRoles)) + for idStr, role := range groupRoles { + // Validate the provided role names + if err := v.ValidateRole(role); err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: groupsField, + Detail: err.Error(), + }) + } + // Validate that the IDs are UUIDs + id, err := uuid.Parse(idStr) + if err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: groupsField, + Detail: fmt.Sprintf("%v is not a valid UUID.", idStr), + }) + continue + } + // Don't check if the ID exists when setting the role to + // WorkspaceRoleDeleted or TemplateRoleDeleted. They might've existing at + // some point and got deleted. If we report that as an error here then they + // can't be removed. + if string(role) == "" { + continue + } + groupIDs = append(groupIDs, id) + } + + // Validate that the groups exist + groupValidation, err := db.ValidateGroupIDs(ctx, groupIDs) + if err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: groupsField, + Detail: fmt.Sprintf("failed to validate group IDs: %v", err.Error()), + }) + } + if !groupValidation.Ok { + for _, id := range groupValidation.InvalidGroupIds { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: groupsField, + Detail: fmt.Sprintf("group with ID %v does not exist", id), + }) + } + } + + userRoles, usersField := v.Users() + userIDs := make([]uuid.UUID, 0, len(userRoles)) + for idStr, role := range userRoles { + // Validate the provided role names + if err := v.ValidateRole(role); err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: usersField, + Detail: err.Error(), + }) + } + // Validate that the IDs are UUIDs + id, err := uuid.Parse(idStr) + if err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: usersField, + Detail: fmt.Sprintf("%v is not a valid UUID.", idStr), + }) + continue + } + // Don't check if the ID exists when setting the role to + // WorkspaceRoleDeleted or TemplateRoleDeleted. They might've existing at + // some point and got deleted. If we report that as an error here then they + // can't be removed. + if string(role) == "" { + continue + } + userIDs = append(userIDs, id) + } + + // Validate that the groups exist + userValidation, err := db.ValidateUserIDs(ctx, userIDs) + if err != nil { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: usersField, + Detail: fmt.Sprintf("failed to validate user IDs: %v", err.Error()), + }) + } + if !userValidation.Ok { + for _, id := range userValidation.InvalidUserIds { + validErrs = append(validErrs, codersdk.ValidationError{ + Field: usersField, + Detail: fmt.Sprintf("user with ID %v does not exist", id), + }) + } + } + + return validErrs +} diff --git a/coderd/rbac/acl/updatevalidator_test.go b/coderd/rbac/acl/updatevalidator_test.go new file mode 100644 index 0000000000000..0e394370b1356 --- /dev/null +++ b/coderd/rbac/acl/updatevalidator_test.go @@ -0,0 +1,91 @@ +package acl_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/rbac/acl" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestOK(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + o := dbgen.Organization(t, db, database.Organization{}) + g := dbgen.Group(t, db, database.Group{OrganizationID: o.ID}) + u := dbgen.User(t, db, database.User{}) + ctx := testutil.Context(t, testutil.WaitShort) + + update := codersdk.UpdateWorkspaceACL{ + UserRoles: map[string]codersdk.WorkspaceRole{ + u.ID.String(): codersdk.WorkspaceRoleAdmin, + // An unknown ID is allowed if and only if the specified role is either + // codersdk.WorkspaceRoleDeleted or codersdk.TemplateRoleDeleted. + uuid.NewString(): codersdk.WorkspaceRoleDeleted, + }, + GroupRoles: map[string]codersdk.WorkspaceRole{ + g.ID.String(): codersdk.WorkspaceRoleAdmin, + // An unknown ID is allowed if and only if the specified role is either + // codersdk.WorkspaceRoleDeleted or codersdk.TemplateRoleDeleted. + uuid.NewString(): codersdk.WorkspaceRoleDeleted, + }, + } + errors := acl.Validate(ctx, db, coderd.WorkspaceACLUpdateValidator(update)) + require.Empty(t, errors) +} + +func TestDeniesUnknownIDs(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + update := codersdk.UpdateWorkspaceACL{ + UserRoles: map[string]codersdk.WorkspaceRole{ + uuid.NewString(): codersdk.WorkspaceRoleAdmin, + }, + GroupRoles: map[string]codersdk.WorkspaceRole{ + uuid.NewString(): codersdk.WorkspaceRoleAdmin, + }, + } + errors := acl.Validate(ctx, db, coderd.WorkspaceACLUpdateValidator(update)) + require.Len(t, errors, 2) + require.Equal(t, errors[0].Field, "group_roles") + require.ErrorContains(t, errors[0], "does not exist") + require.Equal(t, errors[1].Field, "user_roles") + require.ErrorContains(t, errors[1], "does not exist") +} + +func TestDeniesUnknownRolesAndInvalidIDs(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + update := codersdk.UpdateWorkspaceACL{ + UserRoles: map[string]codersdk.WorkspaceRole{ + "Quifrey": "level 5", + }, + GroupRoles: map[string]codersdk.WorkspaceRole{ + "apprentices": "level 2", + }, + } + errors := acl.Validate(ctx, db, coderd.WorkspaceACLUpdateValidator(update)) + require.Len(t, errors, 4) + require.Equal(t, errors[0].Field, "group_roles") + require.ErrorContains(t, errors[0], "role \"level 2\" is not a valid workspace role") + require.Equal(t, errors[1].Field, "group_roles") + require.ErrorContains(t, errors[1], "not a valid UUID") + require.Equal(t, errors[2].Field, "user_roles") + require.ErrorContains(t, errors[2], "role \"level 5\" is not a valid workspace role") + require.Equal(t, errors[3].Field, "user_roles") + require.ErrorContains(t, errors[3], "not a valid UUID") +} diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 8d1376e7e6939..6da85c7608ca4 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/acl" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" @@ -2086,17 +2087,10 @@ func (api *API) patchWorkspaceACL(rw http.ResponseWriter, r *http.Request) { return } - validErrs := validateWorkspaceACLPerms(ctx, api.Database, req.UserRoles, "user_roles") - validErrs = append(validErrs, validateWorkspaceACLPerms( - ctx, - api.Database, - req.GroupRoles, - "group_roles", - )...) - + validErrs := acl.Validate(ctx, api.Database, WorkspaceACLUpdateValidator(req)) if len(validErrs) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid request to update template metadata!", + Message: "Invalid request to update workspace ACL", Validations: validErrs, }) return @@ -2492,50 +2486,28 @@ func (api *API) publishWorkspaceAgentLogsUpdate(ctx context.Context, workspaceAg } } -func validateWorkspaceACLPerms(ctx context.Context, db database.Store, perms map[string]codersdk.WorkspaceRole, field string) []codersdk.ValidationError { - // nolint:gocritic // Validate requires full read access to users and groups - ctx = dbauthz.AsSystemRestricted(ctx) - var validErrs []codersdk.ValidationError - for idStr, role := range perms { - if err := validateWorkspaceRole(role); err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: err.Error()}) - continue - } +type WorkspaceACLUpdateValidator codersdk.UpdateWorkspaceACL - id, err := uuid.Parse(idStr) - if err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: idStr + "is not a valid UUID."}) - continue - } +var ( + workspaceACLUpdateUsersFieldName = "user_roles" + workspaceACLUpdateGroupsFieldName = "group_roles" +) - switch field { - case "user_roles": - // TODO(lilac): put this back after Kirby button shenanigans are over - // This could get slow if we get a ton of user perm updates. - // _, err = db.GetUserByID(ctx, id) - // if err != nil { - // validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: fmt.Sprintf("Failed to find resource with ID %q: %v", idStr, err.Error())}) - // continue - // } - case "group_roles": - // This could get slow if we get a ton of group perm updates. - _, err = db.GetGroupByID(ctx, id) - if err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: fmt.Sprintf("Failed to find resource with ID %q: %v", idStr, err.Error())}) - continue - } - default: - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: "invalid field"}) - } - } +// WorkspaceACLUpdateValidator implements acl.UpdateValidator[codersdk.WorkspaceRole] +var _ acl.UpdateValidator[codersdk.WorkspaceRole] = WorkspaceACLUpdateValidator{} + +func (w WorkspaceACLUpdateValidator) Users() (map[string]codersdk.WorkspaceRole, string) { + return w.UserRoles, workspaceACLUpdateUsersFieldName +} - return validErrs +func (w WorkspaceACLUpdateValidator) Groups() (map[string]codersdk.WorkspaceRole, string) { + return w.GroupRoles, workspaceACLUpdateGroupsFieldName } -func validateWorkspaceRole(role codersdk.WorkspaceRole) error { +func (WorkspaceACLUpdateValidator) ValidateRole(role codersdk.WorkspaceRole) error { actions := db2sdk.WorkspaceRoleActions(role) if len(actions) == 0 && role != codersdk.WorkspaceRoleDeleted { - return xerrors.Errorf("role %q is not a valid Workspace role", role) + return xerrors.Errorf("role %q is not a valid workspace role", role) } return nil diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 9583e14cd7fd3..40569ead70658 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -12,20 +12,20 @@ import ( "sync" "time" - "github.com/coder/quartz" - "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" agplportsharing "github.com/coder/coder/v2/coderd/portsharing" + "github.com/coder/coder/v2/coderd/pproflabel" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" + "github.com/coder/quartz" "golang.org/x/xerrors" "tailscale.com/tailcfg" @@ -903,7 +903,9 @@ func (api *API) updateEntitlements(ctx context.Context) error { } api.AGPL.PrebuildsReconciler.Store(&reconciler) - go reconciler.Run(context.Background()) + // TODO: Should this context be the api.ctx context? To cancel when + // the API (and entire app) is closed via shutdown? + pproflabel.Go(context.Background(), pproflabel.Service(pproflabel.ServicePrebuildReconciler), reconciler.Run) api.AGPL.PrebuildsClaimer.Store(&claimer) } diff --git a/enterprise/coderd/templates.go b/enterprise/coderd/templates.go index 438a7cfd5c65f..07323dce3c7e6 100644 --- a/enterprise/coderd/templates.go +++ b/enterprise/coderd/templates.go @@ -1,7 +1,6 @@ package coderd import ( - "context" "database/sql" "fmt" "net/http" @@ -15,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac/acl" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" @@ -208,17 +208,10 @@ func (api *API) patchTemplateACL(rw http.ResponseWriter, r *http.Request) { return } - validErrs := validateTemplateACLPerms(ctx, api.Database, req.UserPerms, "user_perms") - validErrs = append(validErrs, validateTemplateACLPerms( - ctx, - api.Database, - req.GroupPerms, - "group_perms", - )...) - + validErrs := acl.Validate(ctx, api.Database, TemplateACLUpdateValidator(req)) if len(validErrs) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid request to update template metadata!", + Message: "Invalid request to update template ACL", Validations: validErrs, }) return @@ -273,43 +266,31 @@ func (api *API) patchTemplateACL(rw http.ResponseWriter, r *http.Request) { }) } -func validateTemplateACLPerms(ctx context.Context, db database.Store, perms map[string]codersdk.TemplateRole, field string) []codersdk.ValidationError { - // nolint:gocritic // Validate requires full read access to users and groups - ctx = dbauthz.AsSystemRestricted(ctx) - var validErrs []codersdk.ValidationError - for idStr, role := range perms { - if err := validateTemplateRole(role); err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: err.Error()}) - continue - } +type TemplateACLUpdateValidator codersdk.UpdateTemplateACL - id, err := uuid.Parse(idStr) - if err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: idStr + "is not a valid UUID."}) - continue - } +var ( + templateACLUpdateUsersFieldName = "user_perms" + templateACLUpdateGroupsFieldName = "group_perms" +) - switch field { - case "user_perms": - // This could get slow if we get a ton of user perm updates. - _, err = db.GetUserByID(ctx, id) - if err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: fmt.Sprintf("Failed to find resource with ID %q: %v", idStr, err.Error())}) - continue - } - case "group_perms": - // This could get slow if we get a ton of group perm updates. - _, err = db.GetGroupByID(ctx, id) - if err != nil { - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: fmt.Sprintf("Failed to find resource with ID %q: %v", idStr, err.Error())}) - continue - } - default: - validErrs = append(validErrs, codersdk.ValidationError{Field: field, Detail: "invalid field"}) - } +// TemplateACLUpdateValidator implements acl.UpdateValidator[codersdk.TemplateRole] +var _ acl.UpdateValidator[codersdk.TemplateRole] = TemplateACLUpdateValidator{} + +func (w TemplateACLUpdateValidator) Users() (map[string]codersdk.TemplateRole, string) { + return w.UserPerms, templateACLUpdateUsersFieldName +} + +func (w TemplateACLUpdateValidator) Groups() (map[string]codersdk.TemplateRole, string) { + return w.GroupPerms, templateACLUpdateGroupsFieldName +} + +func (TemplateACLUpdateValidator) ValidateRole(role codersdk.TemplateRole) error { + actions := db2sdk.TemplateRoleActions(role) + if len(actions) == 0 && role != codersdk.TemplateRoleDeleted { + return xerrors.Errorf("role %q is not a valid template role", role) } - return validErrs + return nil } func convertTemplateUsers(tus []database.TemplateUser, orgIDsByUserIDs map[uuid.UUID][]uuid.UUID) []codersdk.TemplateUser { @@ -325,15 +306,6 @@ func convertTemplateUsers(tus []database.TemplateUser, orgIDsByUserIDs map[uuid. return users } -func validateTemplateRole(role codersdk.TemplateRole) error { - actions := db2sdk.TemplateRoleActions(role) - if len(actions) == 0 && role != codersdk.TemplateRoleDeleted { - return xerrors.Errorf("role %q is not a valid Template role", role) - } - - return nil -} - func convertToTemplateRole(actions []policy.Action) codersdk.TemplateRole { switch { case len(actions) == 2 && slice.SameElements(actions, []policy.Action{policy.ActionUse, policy.ActionRead}): diff --git a/enterprise/coderd/templates_test.go b/enterprise/coderd/templates_test.go index 6c7a20f85a642..d95450e28e8aa 100644 --- a/enterprise/coderd/templates_test.go +++ b/enterprise/coderd/templates_test.go @@ -1413,13 +1413,40 @@ func TestUpdateTemplateACL(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) req := codersdk.UpdateTemplateACL{ UserPerms: map[string]codersdk.TemplateRole{ - "hi": "admin", + "hi": codersdk.TemplateRoleAdmin, }, } ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // we're testing invalid UUID so testing RBAC is not relevant here. + //nolint:gocritic // Testing ACL validation + err := client.UpdateTemplateACL(ctx, template.ID, req) + require.Error(t, err) + cerr, _ := codersdk.AsError(err) + require.Equal(t, http.StatusBadRequest, cerr.StatusCode()) + }) + + // We should report invalid UUIDs as errors + t.Run("DeleteRoleForInvalidUUID", func(t *testing.T) { + t.Parallel() + + client, user := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }}) + + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + req := codersdk.UpdateTemplateACL{ + UserPerms: map[string]codersdk.TemplateRole{ + "hi": codersdk.TemplateRoleDeleted, + }, + } + + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Testing ACL validation err := client.UpdateTemplateACL(ctx, template.ID, req) require.Error(t, err) cerr, _ := codersdk.AsError(err) @@ -1445,13 +1472,75 @@ func TestUpdateTemplateACL(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // we're testing invalid user so testing RBAC is not relevant here. + //nolint:gocritic // Testing ACL validation err := client.UpdateTemplateACL(ctx, template.ID, req) require.Error(t, err) cerr, _ := codersdk.AsError(err) require.Equal(t, http.StatusBadRequest, cerr.StatusCode()) }) + // We should allow the special "Delete" role for valid UUIDs that don't + // correspond to a valid user, because the user might have been deleted. + t.Run("DeleteRoleForDeletedUser", func(t *testing.T) { + t.Parallel() + + client, user := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }}) + + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + ctx := testutil.Context(t, testutil.WaitLong) + + _, deletedUser := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + //nolint:gocritic // Can't delete yourself + err := client.DeleteUser(ctx, deletedUser.ID) + require.NoError(t, err) + + req := codersdk.UpdateTemplateACL{ + UserPerms: map[string]codersdk.TemplateRole{ + deletedUser.ID.String(): codersdk.TemplateRoleDeleted, + }, + } + //nolint:gocritic // Testing ACL validation + err = client.UpdateTemplateACL(ctx, template.ID, req) + require.NoError(t, err) + }) + + t.Run("DeletedUser", func(t *testing.T) { + t.Parallel() + + client, user := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }}) + + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + ctx := testutil.Context(t, testutil.WaitLong) + + _, deletedUser := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + //nolint:gocritic // Can't delete yourself + err := client.DeleteUser(ctx, deletedUser.ID) + require.NoError(t, err) + + req := codersdk.UpdateTemplateACL{ + UserPerms: map[string]codersdk.TemplateRole{ + deletedUser.ID.String(): codersdk.TemplateRoleAdmin, + }, + } + //nolint:gocritic // Testing ACL validation + err = client.UpdateTemplateACL(ctx, template.ID, req) + require.Error(t, err) + cerr, _ := codersdk.AsError(err) + require.Equal(t, http.StatusBadRequest, cerr.StatusCode()) + }) + t.Run("InvalidRole", func(t *testing.T) { t.Parallel() @@ -1472,7 +1561,7 @@ func TestUpdateTemplateACL(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // we're testing invalid role so testing RBAC is not relevant here. + //nolint:gocritic // Testing ACL validation err := client.UpdateTemplateACL(ctx, template.ID, req) require.Error(t, err) cerr, _ := codersdk.AsError(err) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index 69241d8aa1c17..c2ac1baf2db4e 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -333,6 +333,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { r.Use( // TODO: @emyrk Should we standardize these in some other package? httpmw.Recover(s.Logger), + httpmw.WithProfilingLabels, tracing.StatusWriterMiddleware, tracing.Middleware(s.TracerProvider), httpmw.AttachRequestID, diff --git a/site/src/pages/TaskPage/TaskApps.stories.tsx b/site/src/pages/TaskPage/TaskApps.stories.tsx new file mode 100644 index 0000000000000..c3006e2db0a14 --- /dev/null +++ b/site/src/pages/TaskPage/TaskApps.stories.tsx @@ -0,0 +1,134 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import type { WorkspaceApp } from "api/typesGenerated"; +import { + MockTasks, + MockWorkspace, + MockWorkspaceAgent, + MockWorkspaceApp, +} from "testHelpers/entities"; +import { withProxyProvider } from "testHelpers/storybook"; +import { TaskApps } from "./TaskApps"; + +const meta: Meta = { + title: "pages/TaskPage/TaskApps", + component: TaskApps, + decorators: [withProxyProvider()], + parameters: { + layout: "fullscreen", + }, +}; + +export default meta; +type Story = StoryObj; + +const mockAgentNoApps = { + ...MockWorkspaceAgent, + apps: [], +}; + +const mockExternalApp: WorkspaceApp = { + ...MockWorkspaceApp, + external: true, +}; + +const mockEmbeddedApp: WorkspaceApp = { + ...MockWorkspaceApp, + external: false, +}; + +const taskWithNoApps = { + ...MockTasks[0], + workspace: { + ...MockWorkspace, + latest_build: { + ...MockWorkspace.latest_build, + resources: [ + { + ...MockWorkspace.latest_build.resources[0], + agents: [mockAgentNoApps], + }, + ], + }, + }, +}; + +export const NoEmbeddedApps: Story = { + args: { + task: taskWithNoApps, + }, +}; + +export const WithExternalAppsOnly: Story = { + args: { + task: { + ...MockTasks[0], + workspace: { + ...MockWorkspace, + latest_build: { + ...MockWorkspace.latest_build, + resources: [ + { + ...MockWorkspace.latest_build.resources[0], + agents: [ + { + ...MockWorkspaceAgent, + apps: [mockExternalApp], + }, + ], + }, + ], + }, + }, + }, + }, +}; + +export const WithEmbeddedApps: Story = { + args: { + task: { + ...MockTasks[0], + workspace: { + ...MockWorkspace, + latest_build: { + ...MockWorkspace.latest_build, + resources: [ + { + ...MockWorkspace.latest_build.resources[0], + agents: [ + { + ...MockWorkspaceAgent, + apps: [mockEmbeddedApp], + }, + ], + }, + ], + }, + }, + }, + }, +}; + +export const WithMixedApps: Story = { + args: { + task: { + ...MockTasks[0], + workspace: { + ...MockWorkspace, + latest_build: { + ...MockWorkspace.latest_build, + resources: [ + { + ...MockWorkspace.latest_build.resources[0], + agents: [ + { + ...MockWorkspaceAgent, + apps: [mockEmbeddedApp, mockExternalApp], + }, + ], + }, + ], + }, + }, + }, + }, +}; diff --git a/site/src/pages/TaskPage/TaskApps.tsx b/site/src/pages/TaskPage/TaskApps.tsx index 0cccc8c7a01df..83cd01f37c004 100644 --- a/site/src/pages/TaskPage/TaskApps.tsx +++ b/site/src/pages/TaskPage/TaskApps.tsx @@ -1,4 +1,4 @@ -import type { WorkspaceApp } from "api/typesGenerated"; +import type { WorkspaceAgent, WorkspaceApp } from "api/typesGenerated"; import { Button } from "components/Button/Button"; import { DropdownMenu, @@ -8,6 +8,7 @@ import { } from "components/DropdownMenu/DropdownMenu"; import { ExternalImage } from "components/ExternalImage/ExternalImage"; import { InfoTooltip } from "components/InfoTooltip/InfoTooltip"; +import { Link } from "components/Link/Link"; import { ChevronDownIcon, LayoutGridIcon } from "lucide-react"; import { useAppLink } from "modules/apps/useAppLink"; import type { Task } from "modules/tasks/tasks"; @@ -15,12 +16,18 @@ import type React from "react"; import { type FC, useState } from "react"; import { Link as RouterLink } from "react-router-dom"; import { cn } from "utils/cn"; +import { docs } from "utils/docs"; import { TaskAppIFrame } from "./TaskAppIframe"; type TaskAppsProps = { task: Task; }; +type AppWithAgent = { + app: WorkspaceApp; + agent: WorkspaceAgent; +}; + export const TaskApps: FC = ({ task }) => { const agents = task.workspace.latest_build.resources .flatMap((r) => r.agents) @@ -29,43 +36,34 @@ export const TaskApps: FC = ({ task }) => { // The Chat UI app will be displayed in the sidebar, so we don't want to show // it here const apps = agents - .flatMap((a) => a?.apps) + .flatMap((agent) => + agent.apps.map((app) => ({ + app, + agent, + })), + ) .filter( - (a) => !!a && a.id !== task.workspace.latest_build.ai_task_sidebar_app_id, + ({ app }) => + !!app && app.id !== task.workspace.latest_build.ai_task_sidebar_app_id, ); - const embeddedApps = apps.filter((app) => !app.external); - const externalApps = apps.filter((app) => app.external); - - const [activeAppId, setActiveAppId] = useState(() => { - const appId = embeddedApps[0]?.id; - if (!appId) { - throw new Error("No apps found in task"); - } - return appId; - }); - - const activeApp = apps.find((app) => app.id === activeAppId); - if (!activeApp) { - throw new Error(`Active app with ID ${activeAppId} not found in task`); - } + const embeddedApps = apps.filter(({ app }) => !app.external); + const externalApps = apps.filter(({ app }) => app.external); - const agent = agents.find((a) => - a.apps.some((app) => app.id === activeAppId), + const [activeAppId, setActiveAppId] = useState( + embeddedApps[0]?.app.id, ); - if (!agent) { - throw new Error(`Agent for app ${activeAppId} not found in task workspace`); - } return (
- {embeddedApps.map((app) => ( + {embeddedApps.map(({ app, agent }) => ( { e.preventDefault(); @@ -76,73 +74,110 @@ export const TaskApps: FC = ({ task }) => {
{externalApps.length > 0 && ( -
- - - - - - {externalApps.map((app) => { - const link = useAppLink(app, { - agent, - workspace: task.workspace, - }); - - return ( - - - {app.icon ? ( - - ) : ( - - )} - {link.label} - - - ); - })} - - -
+ )}
-
- {embeddedApps.map((app) => { - return ( - - ); - })} -
+ {embeddedApps.length > 0 ? ( +
+ {embeddedApps.map(({ app }) => { + return ( + + ); + })} +
+ ) : ( +
+

+ No embedded apps found. +

+ + + + Learn how to configure apps + {" "} + for your tasks. + +
+ )}
); }; +type TaskExternalAppsDropdownProps = { + task: Task; + agents: WorkspaceAgent[]; + externalApps: AppWithAgent[]; +}; + +const TaskExternalAppsDropdown: FC = ({ + task, + agents, + externalApps, +}) => { + return ( +
+ + + + + + {externalApps.map(({ app, agent }) => { + const link = useAppLink(app, { + agent, + workspace: task.workspace, + }); + + return ( + + + {app.icon ? ( + + ) : ( + + )} + {link.label} + + + ); + })} + + +
+ ); +}; + type TaskAppTabProps = { task: Task; app: WorkspaceApp; + agent: WorkspaceAgent; active: boolean; onClick: (e: React.MouseEvent) => void; }; -const TaskAppTab: FC = ({ task, app, active, onClick }) => { - const agent = task.workspace.latest_build.resources - .flatMap((r) => r.agents) - .filter((a) => !!a) - .find((a) => a.apps.some((a) => a.id === app.id)); - - if (!agent) { - throw new Error(`Agent for app ${app.id} not found in task workspace`); - } - +const TaskAppTab: FC = ({ + task, + app, + agent, + active, + onClick, +}) => { const link = useAppLink(app, { agent, workspace: task.workspace,