diff --git a/cli/server.go b/cli/server.go index 602f05d028b66..26d0c8f110403 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1101,7 +1101,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value()) defer autobuildTicker.Stop() autobuildExecutor := autobuild.NewExecutor( - ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) + ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) autobuildExecutor.Run() jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value()) diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index d49bf831515d0..234a72de04c50 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -42,6 +42,7 @@ type Executor struct { templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] accessControlStore *atomic.Pointer[dbauthz.AccessControlStore] auditor *atomic.Pointer[audit.Auditor] + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] log slog.Logger tick <-chan time.Time statsCh chan<- Stats @@ -65,7 +66,7 @@ type Stats struct { } // New returns a new wsactions executor. -func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { +func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { factory := promauto.With(reg) le := &Executor{ //nolint:gocritic // Autostart has a limited set of permissions. @@ -78,6 +79,7 @@ func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *f log: log.Named("autobuild"), auditor: auditor, accessControlStore: acs, + buildUsageChecker: buildUsageChecker, notificationsEnqueuer: enqueuer, reg: reg, experiments: exp, @@ -279,7 +281,7 @@ func (e *Executor) runOnce(t time.Time) Stats { } if nextTransition != "" { - builder := wsbuilder.New(ws, nextTransition). + builder := wsbuilder.New(ws, nextTransition, *e.buildUsageChecker.Load()). SetLastWorkspaceBuildInTx(&latestBuild). SetLastWorkspaceBuildJobInTx(&latestJob). Experiments(e.experiments). diff --git a/coderd/coderd.go b/coderd/coderd.go index fa10846a7d0a6..9115888fc566b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/oauth2provider" "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/andybalholm/brotli" "github.com/go-chi/chi/v5" @@ -559,6 +560,13 @@ func New(options *Options) *API { // bugs that may only occur when a key isn't precached in tests and the latency cost is minimal. cryptokeys.StartRotator(ctx, options.Logger, options.Database) + // AGPL uses a no-op build usage checker as there are no license + // entitlements to enforce. This is swapped out in + // enterprise/coderd/coderd.go. + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + api := &API{ ctx: ctx, cancel: cancel, @@ -579,6 +587,7 @@ func New(options *Options) *API { TemplateScheduleStore: options.TemplateScheduleStore, UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore, AccessControlStore: options.AccessControlStore, + BuildUsageChecker: &buildUsageChecker, FileCache: files.New(options.PrometheusRegistry, options.Authorizer), Experiments: experiments, WebpushDispatcher: options.WebPushDispatcher, @@ -1650,6 +1659,9 @@ type API struct { FileCache *files.Cache PrebuildsClaimer atomic.Pointer[prebuilds.Claimer] PrebuildsReconciler atomic.Pointer[prebuilds.ReconciliationOrchestrator] + // BuildUsageChecker is a pointer as it's passed around to multiple + // components. + BuildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] UpdatesProvider tailnet.WorkspaceUpdatesProvider diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 96030b215e5dd..7085068e97ff4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/quartz" "github.com/coder/coder/v2/coderd" @@ -364,6 +365,10 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } connectionLogger.Store(&options.ConnectionLogger) + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + ctx, cancelFunc := context.WithCancel(context.Background()) experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments) lifecycleExecutor := autobuild.NewExecutor( @@ -375,6 +380,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can &templateScheduleStore, &auditor, accessControlStore, + &buildUsageChecker, *options.Logger, options.AutobuildTicker, options.NotificationsEnqueuer, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a12db9aa6919f..257cbc6e6b142 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2193,6 +2193,14 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) { return q.db.GetLogoURL(ctx) } +func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + // Must be able to read all workspaces to check usage. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace); err != nil { + return 0, xerrors.Errorf("authorize read all workspaces: %w", err) + } + return q.db.GetManagedAgentCount(ctx, arg) +} + func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil { return nil, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2b0801024eb8d..bcf0caa95c365 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -17,20 +17,18 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/notifications" - "github.com/coder/coder/v2/coderd/rbac/policy" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" ) @@ -903,6 +901,14 @@ func (s *MethodTestSuite) TestLicense() { require.NoError(s.T(), err) check.Args().Asserts().Returns("value") })) + s.Run("GetManagedAgentCount", s.Subtest(func(db database.Store, check *expects) { + start := dbtime.Now() + end := start.Add(time.Hour) + check.Args(database.GetManagedAgentCountParams{ + StartTime: start, + EndTime: end, + }).Asserts(rbac.ResourceWorkspace, policy.ActionRead).Returns(int64(0)) + })) } func (s *MethodTestSuite) TestOrganization() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index d4e1db1612790..811d945ac7da9 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -964,6 +964,13 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) { return url, err } +func (m queryMetricsStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.GetManagedAgentCount(ctx, arg) + m.queryLatencies.WithLabelValues("GetManagedAgentCount").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { start := time.Now() r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f3ed6c2bc78ca..b20c3d06209b5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2012,6 +2012,21 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx) } +// GetManagedAgentCount mocks base method. +func (m *MockStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetManagedAgentCount", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetManagedAgentCount indicates an expected call of GetManagedAgentCount. +func (mr *MockStoreMockRecorder) GetManagedAgentCount(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedAgentCount", reflect.TypeOf((*MockStore)(nil).GetManagedAgentCount), ctx, arg) +} + // GetNotificationMessagesByStatus mocks base method. func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6471d79defa6c..baa5d8590b1d7 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -216,6 +216,8 @@ type sqlcQuerier interface { GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) + // This isn't strictly a license query, but it's related to license enforcement. + GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) // Fetch the notification report generator log indicating recent activity. GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 47d46a4e74a8b..4bf01000de0ec 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4286,6 +4286,44 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getManagedAgentCount = `-- name: GetManagedAgentCount :one +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. + AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz +` + +type GetManagedAgentCountParams struct { + StartTime time.Time `db:"start_time" json:"start_time"` + EndTime time.Time `db:"end_time" json:"end_time"` +} + +// This isn't strictly a license query, but it's related to license enforcement. +func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime) + var count int64 + err := row.Scan(&count) + return count, err +} + const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many SELECT id, uploaded_at, jwt, exp, uuid FROM licenses diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index 3512a46514787..ac864a94d1792 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -35,3 +35,28 @@ DELETE FROM licenses WHERE id = $1 RETURNING id; + +-- name: GetManagedAgentCount :one +-- This isn't strictly a license query, but it's related to license enforcement. +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. + AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz; diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 88774c63368ca..884a963405007 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -335,7 +335,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { return } - builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition)). + builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition), *api.BuildUsageChecker.Load()). Initiator(apiKey.UserID). RichParameterValues(createBuild.RichParameterValues). LogLevel(string(createBuild.LogLevel)). diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 32b412946907e..0f3f0a24c75d3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -701,7 +701,7 @@ func createWorkspace( return xerrors.Errorf("get workspace by ID: %w", err) } - builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart). + builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart, *api.BuildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(initiatorID). ActiveVersion(). diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index d608682c58eee..52567b463baac 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -56,6 +56,7 @@ type Builder struct { logLevel string deploymentValues *codersdk.DeploymentValues experiments codersdk.Experiments + usageChecker UsageChecker richParameterValues []codersdk.WorkspaceBuildParameter initiator uuid.UUID @@ -89,7 +90,24 @@ type Builder struct { verifyNoLegacyParametersOnce bool } -type Option func(Builder) Builder +type UsageChecker interface { + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error) +} + +type UsageCheckResponse struct { + Permitted bool + Message string +} + +type NoopUsageChecker struct{} + +var _ UsageChecker = NoopUsageChecker{} + +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) { + return UsageCheckResponse{ + Permitted: true, + }, nil +} // versionTarget expresses how to determine the template version for the build. // @@ -121,8 +139,8 @@ type stateTarget struct { explicit *[]byte } -func New(w database.Workspace, t database.WorkspaceTransition) Builder { - return Builder{workspace: w, trans: t} +func New(w database.Workspace, t database.WorkspaceTransition, uc UsageChecker) Builder { + return Builder{workspace: w, trans: t, usageChecker: uc} } // Methods that customize the build are public, have a struct receiver and return a new Builder. @@ -321,6 +339,10 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object if err != nil { return nil, nil, nil, err } + err = b.checkUsage() + if err != nil { + return nil, nil, nil, err + } err = b.checkRunningBuild() if err != nil { return nil, nil, nil, err @@ -1253,6 +1275,23 @@ func (b *Builder) checkTemplateJobStatus() error { return nil } +func (b *Builder) checkUsage() error { + templateVersion, err := b.getTemplateVersion() + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} + } + + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to check build usage", err} + } + if !resp.Permitted { + return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil} + } + + return nil +} + func (b *Builder) checkRunningBuild() error { job, err := b.getLastBuildJob() if xerrors.Is(err, sql.ErrNoRows) { diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index 41ea3fe2c9921..ee421a8adb649 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -5,30 +5,30 @@ import ( "database/sql" "encoding/json" "net/http" + "sync/atomic" "testing" "time" - "github.com/prometheus/client_golang/prometheus" - - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/files" - "github.com/coder/coder/v2/coderd/httpapi/httperror" - "github.com/coder/coder/v2/provisionersdk" - "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" "go.uber.org/mock/gomock" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/provisionersdk" ) var ( @@ -102,7 +102,7 @@ func TestBuilder_NoOptions(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -142,7 +142,8 @@ func TestBuilder_Initiator(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -188,7 +189,8 @@ func TestBuilder_Baggage(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{IP: "127.0.0.1"}) req.NoError(err) @@ -227,7 +229,8 @@ func TestBuilder_Reason(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Reason(database.BuildReasonAutostart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Reason(database.BuildReasonAutostart) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -271,7 +274,8 @@ func TestBuilder_ActiveVersion(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).ActiveVersion() + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + ActiveVersion() // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -386,7 +390,8 @@ func TestWorkspaceBuildWithTags(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(buildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(buildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -469,7 +474,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -517,7 +523,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -555,7 +562,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) + // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} req.ErrorAs(err, &bldErr) @@ -591,7 +599,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} @@ -656,7 +665,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -720,7 +729,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -782,7 +791,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) // nolint: dogsled @@ -849,7 +858,7 @@ func TestWorkspaceBuildWithPreset(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). ActiveVersion(). TemplateVersionPresetID(presetID) // nolint: dogsled @@ -916,7 +925,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled @@ -993,7 +1002,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -1001,6 +1010,115 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }) } +func TestWorkspaceBuildUsageChecker(t *testing.T) { + t.Parallel() + + t.Run("Permitted", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + }, + } + + mDB := expectDB(t, + // Inputs + withTemplate, + withInactiveVersion(nil), + withLastBuildFound, + withTemplateVersionVariables(inactiveVersionID, nil), + withRichParameters(nil), + withParameterSchemas(inactiveJobID, nil), + withWorkspaceTags(inactiveVersionID, nil), + withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), + + // Outputs + expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), + withInTx, + expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), + withBuild, + expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + require.NoError(t, err) + require.EqualValues(t, 1, calls) + }) + + // The failure cases are mostly identical from a test perspective. + const message = "fake test message" + cases := []struct { + name string + response wsbuilder.UsageCheckResponse + responseErr error + assertions func(t *testing.T, err error) + }{ + { + name: "NotPermitted", + response: wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: message, + }, + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, message) + var buildErr wsbuilder.BuildError + require.ErrorAs(t, err, &buildErr) + require.Equal(t, http.StatusForbidden, buildErr.Status) + }, + }, + { + name: "Error", + responseErr: xerrors.New("fake error"), + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, "fake error") + require.ErrorAs(t, err, &wsbuilder.BuildError{}) + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return c.response, c.responseErr + }, + } + + mDB := expectDB(t, + withTemplate, + withInactiveVersionNoParams(), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker). + VersionID(inactiveVersionID) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + c.assertions(t, err) + require.EqualValues(t, 1, calls) + }) + } +} + func TestWsbuildError(t *testing.T) { t.Parallel() @@ -1366,3 +1484,11 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner mTx.EXPECT().GetEligibleProvisionerDaemonsByProvisionerJobIDs(gomock.Any(), gomock.Any()).Return(provisionerDaemons, nil) } } + +type fakeUsageChecker struct { + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) +} + +func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion) +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 0d176567713a2..d6e47f4cfdf00 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -22,6 +22,7 @@ import ( agplportsharing "github.com/coder/coder/v2/coderd/portsharing" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" @@ -916,10 +917,70 @@ func (api *API) updateEntitlements(ctx context.Context) error { reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg) } reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption + + // If there's a license installed, we will use the enterprise build + // limit checker. + // This checker currently only enforces the managed agent limit. + if reloadedEntitlements.HasLicense { + var checker wsbuilder.UsageChecker = api + api.AGPL.BuildUsageChecker.Store(&checker) + } else { + // Don't check any usage, just like AGPL. + var checker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + api.AGPL.BuildUsageChecker.Store(&checker) + } + return reloadedEntitlements, nil }) } +var _ wsbuilder.UsageChecker = &API{} + +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + // We assume that if this function is called, a valid license is installed. + // When there are no licenses installed, a noop usage checker is used + // instead. + + // If the template version doesn't have an AI task, we don't need to check + // usage. + if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil + } + + // Otherwise, we need to check that we haven't breached the managed agent + // limit. + managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit) + if !ok || !managedAgentLimit.Enabled || managedAgentLimit.Limit == nil || managedAgentLimit.UsagePeriod == nil { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "Your license is not entitled to managed agents. Please contact sales to continue using managed agents.", + }, nil + } + + // This check is intentionally not committed to the database. It's fine if + // it's not 100% accurate or allows for minor breaches due to build races. + managedAgentCount, err := store.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: managedAgentLimit.UsagePeriod.Start, + EndTime: managedAgentLimit.UsagePeriod.End, + }) + if err != nil { + return wsbuilder.UsageCheckResponse{}, xerrors.Errorf("get managed agent count: %w", err) + } + + if managedAgentCount >= *managedAgentLimit.Limit { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "You have breached the managed agent limit in your license. Please contact sales to continue using managed agents.", + }, nil + } + + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil +} + // getProxyDERPStartingRegionID returns the starting region ID that should be // used for workspace proxies. A proxy's actual region ID is the return value // from this function + it's RegionID field. @@ -1186,6 +1247,6 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio } reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds, - api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer) + api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker) return reconciler, prebuilds.NewEnterpriseClaimer(api.Database) } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 52301f6dae034..42645a98b06c2 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -32,6 +32,8 @@ import ( "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/enterprise/coderd/prebuilds" + "github.com/coder/coder/v2/provisioner/echo" + "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/retry" @@ -621,6 +623,88 @@ func TestSCIMDisabled(t *testing.T) { } } +func TestManagedAgentLimit(t *testing.T) { + t.Parallel() + + cli, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + LicenseOptions: (&coderdenttest.LicenseOptions{}).ManagedAgentLimit(1, 1), + }) + + // It's fine that the app ID is only used in a single successful workspace + // build. + appID := uuid.NewString() + echoRes := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Plan: []byte("{}"), + ModuleFiles: []byte{}, + HasAiTasks: true, + }, + }, + }, + }, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Apply{ + Apply: &proto.ApplyComplete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + Auth: &proto.Agent_Token{ + Token: uuid.NewString(), + }, + Apps: []*proto.App{{ + Id: appID, + Slug: "test", + Url: "http://localhost:1234", + }}, + }}, + }}, + AiTasks: []*proto.AITask{{ + Id: uuid.NewString(), + SidebarApp: &proto.AITaskSidebarApp{ + Id: appID, + }, + }}, + }, + }, + }}, + } + + // Create two templates, one with AI and one without. + aiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, echoRes) + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, aiVersion.ID) + aiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, aiVersion.ID) + noAiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, nil) // use default responses + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, noAiVersion.ID) + noAiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, noAiVersion.ID) + + // Create one AI workspace, which should succeed. + workspace := coderdtest.CreateWorkspace(t, cli, aiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) + + // Create a second AI workspace, which should fail. This needs to be done + // manually because coderdtest.CreateWorkspace expects it to succeed. + _, err := cli.CreateUserWorkspace(context.Background(), codersdk.Me, codersdk.CreateWorkspaceRequest{ //nolint:gocritic // owners must still be subject to the limit + TemplateID: aiTemplate.ID, + Name: coderdtest.RandomUsername(t), + AutomaticUpdates: codersdk.AutomaticUpdatesNever, + }) + require.ErrorContains(t, err, "You have breached the managed agent limit in your license") + + // Create a third non-AI workspace, which should succeed. + workspace = coderdtest.CreateWorkspace(t, cli, noAiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) +} + // testDBAuthzRole returns a context with a subject that has a role // with permissions required for test setup. func testDBAuthzRole(ctx context.Context) context.Context { diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index 9371c10c138d8..7776557522f86 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -94,15 +94,15 @@ func Entitlements( return codersdk.Entitlements{}, xerrors.Errorf("query active user count: %w", err) } - // always shows active user count regardless of license entitlements, err := LicensesEntitlements(ctx, now, licenses, enablements, keys, FeatureArguments{ ActiveUserCount: activeUserCount, ReplicaCount: replicaCount, ExternalAuthCount: externalAuthCount, - ManagedAgentCountFn: func(_ context.Context, _ time.Time, _ time.Time) (int64, error) { - // TODO(@deansheather): replace this with a real implementation in a - // follow up PR. - return 0, nil + ManagedAgentCountFn: func(ctx context.Context, startTime time.Time, endTime time.Time) (int64, error) { + return db.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: startTime, + EndTime: endTime, + }) }, }) if err != nil { diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index fac1d2b44bb63..d8203117039cb 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -10,8 +10,10 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" @@ -678,6 +680,67 @@ func TestEntitlements(t *testing.T) { require.Len(t, entitlements.Warnings, 1) require.Equal(t, "You have multiple External Auth Providers configured but your license is expired. Reduce to one.", entitlements.Warnings[0]) }) + + t.Run("ManagedAgentLimitHasValue", func(t *testing.T) { + t.Parallel() + + // Use a mock database for this test so I don't need to make real + // workspace builds. + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + licenseOpts := (&coderdenttest.LicenseOptions{ + FeatureSet: codersdk.FeatureSetPremium, + IssuedAt: dbtime.Now().Add(-2 * time.Hour).Truncate(time.Second), + NotBefore: dbtime.Now().Add(-time.Hour).Truncate(time.Second), + GraceAt: dbtime.Now().Add(time.Hour * 24 * 60).Truncate(time.Second), // 60 days to remove warning + ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90).Truncate(time.Second), // 90 days to remove warning + }). + UserLimit(100). + ManagedAgentLimit(100, 200) + + lic := database.License{ + ID: 1, + JWT: coderdenttest.GenerateLicense(t, *licenseOpts), + Exp: licenseOpts.ExpiresAt, + } + + mDB.EXPECT(). + GetUnexpiredLicenses(gomock.Any()). + Return([]database.License{lic}, nil) + mDB.EXPECT(). + GetActiveUserCount(gomock.Any(), false). + Return(int64(1), nil) + mDB.EXPECT(). + GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool { + // gomock doesn't seem to compare times very nicely. + if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) { + return false + } + if !assert.WithinDuration(t, licenseOpts.ExpiresAt, params.EndTime, time.Second) { + return false + } + return true + })). + Return(int64(175), nil) + + entitlements, err := license.Entitlements(context.Background(), mDB, 1, 0, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + + managedAgentLimit, ok := entitlements.Features[codersdk.FeatureManagedAgentLimit] + require.True(t, ok) + require.NotNil(t, managedAgentLimit.SoftLimit) + require.EqualValues(t, 100, *managedAgentLimit.SoftLimit) + require.NotNil(t, managedAgentLimit.Limit) + require.EqualValues(t, 200, *managedAgentLimit.Limit) + require.NotNil(t, managedAgentLimit.Actual) + require.EqualValues(t, 175, *managedAgentLimit.Actual) + + // Should've also populated a warning. + require.Len(t, entitlements.Warnings, 1) + require.Equal(t, "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.", entitlements.Warnings[0]) + }) } func TestLicenseEntitlements(t *testing.T) { diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go index 67c1f0dd21ade..01195e3485016 100644 --- a/enterprise/coderd/prebuilds/claim_test.go +++ b/enterprise/coderd/prebuilds/claim_test.go @@ -166,7 +166,7 @@ func TestClaimPrebuild(t *testing.T) { defer provisionerCloser.Close() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy) api.AGPL.PrebuildsClaimer.Store(&claimer) diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go index 96c3d071ac48a..1e9f3f5082806 100644 --- a/enterprise/coderd/prebuilds/metricscollector_test.go +++ b/enterprise/coderd/prebuilds/metricscollector_test.go @@ -201,7 +201,7 @@ func TestMetricsCollector(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) createdUsers := []uuid.UUID{database.PrebuildsSystemUserID} @@ -338,7 +338,7 @@ func TestMetricsCollector_DuplicateTemplateNames(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) collector := prebuilds.NewMetricsCollector(db, logger, reconciler) @@ -491,7 +491,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Ensure no pause setting is set (default state) @@ -520,7 +520,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation to paused @@ -549,7 +549,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation back to not paused diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index 049568c7e7f0c..214d1643bb228 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -39,15 +39,16 @@ import ( ) type StoreReconciler struct { - store database.Store - cfg codersdk.PrebuildsConfig - pubsub pubsub.Pubsub - fileCache *files.Cache - logger slog.Logger - clock quartz.Clock - registerer prometheus.Registerer - metrics *MetricsCollector - notifEnq notifications.Enqueuer + store database.Store + cfg codersdk.PrebuildsConfig + pubsub pubsub.Pubsub + fileCache *files.Cache + logger slog.Logger + clock quartz.Clock + registerer prometheus.Registerer + metrics *MetricsCollector + notifEnq notifications.Enqueuer + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] cancelFn context.CancelCauseFunc running atomic.Bool @@ -66,6 +67,7 @@ func NewStoreReconciler(store database.Store, clock quartz.Clock, registerer prometheus.Registerer, notifEnq notifications.Enqueuer, + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], ) *StoreReconciler { reconciler := &StoreReconciler{ store: store, @@ -76,6 +78,7 @@ func NewStoreReconciler(store database.Store, clock: clock, registerer: registerer, notifEnq: notifEnq, + buildUsageChecker: buildUsageChecker, done: make(chan struct{}, 1), provisionNotifyCh: make(chan database.ProvisionerJob, 10), } @@ -738,7 +741,7 @@ func (c *StoreReconciler) provision( }) } - builder := wsbuilder.New(workspace, transition). + builder := wsbuilder.New(workspace, transition, *c.buildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(database.PrebuildsSystemUserID). MarkPrebuild() diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index 5ba36912ce5c8..8d2a81e1ade83 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" "sync" + "sync/atomic" "testing" "time" @@ -19,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/notifications/notificationstest" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/coderd/wsbuilder" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/google/uuid" @@ -56,7 +58,7 @@ func TestNoReconciliationActionsIfNoPresets(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given a template version with no presets org := dbgen.Organization(t, db, database.Organization{}) @@ -102,7 +104,7 @@ func TestNoReconciliationActionsIfNoPrebuilds(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given there are presets, but no prebuilds org := dbgen.Organization(t, db, database.Organization{}) @@ -382,7 +384,7 @@ func TestPrebuildReconciliation(t *testing.T) { pubSub = &brokenPublisher{Pubsub: pubSub} } cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Run the reconciliation multiple times to ensure idempotency // 8 was arbitrary, but large enough to reasonably trust the result @@ -460,7 +462,7 @@ func TestMultiplePresetsPerTemplateVersion(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -586,7 +588,7 @@ func TestPrebuildScheduling(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -691,7 +693,7 @@ func TestInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -756,7 +758,7 @@ func TestDeletionOfPrebuiltWorkspaceWithInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -853,7 +855,7 @@ func TestSkippingHardLimitedPresets(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -997,7 +999,7 @@ func TestHardLimitedPresetShouldNotBlockDeletion(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -1191,7 +1193,7 @@ func TestRunLoop(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -1322,7 +1324,7 @@ func TestFailedBuildBackoff(t *testing.T) { ).Leveled(slog.LevelDebug) db, ps := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Given: an active template version with presets and prebuilds configured. const desiredInstances = 2 @@ -1447,7 +1449,8 @@ func TestReconciliationLock(t *testing.T) { slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug), quartz.NewMock(t), prometheus.NewRegistry(), - newNoopEnqueuer()) + newNoopEnqueuer(), + newNoopUsageCheckerPtr()) reconciler.WithReconciliationLock(ctx, logger, func(_ context.Context, _ database.Store) error { lockObtained := mutex.TryLock() // As long as the postgres lock is held, this mutex should always be unlocked when we get here. @@ -1481,7 +1484,7 @@ func TestTrackResourceReplacement(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Given: a template admin to receive a notification. templateAdmin := dbgen.User(t, db, database.User{ @@ -1637,7 +1640,7 @@ func TestExpiredPrebuildsMultipleActions(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset ownerID := uuid.New() @@ -1800,6 +1803,13 @@ func newFakeEnqueuer() *notificationstest.FakeEnqueuer { return notificationstest.NewFakeEnqueuer() } +func newNoopUsageCheckerPtr() *atomic.Pointer[wsbuilder.UsageChecker] { + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker := atomic.Pointer[wsbuilder.UsageChecker]{} + buildUsageChecker.Store(&noopUsageChecker) + return &buildUsageChecker +} + // nolint:revive // It's a control flag, but this is a test. func setupTestDBTemplate( t *testing.T, @@ -2270,7 +2280,7 @@ func TestReconciliationRespectsPauseSetting(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Setup a template with a preset that should create prebuilds org := dbgen.Organization(t, db, database.Organization{}) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index d622748899aa0..2278fb2a71939 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -1864,6 +1864,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2004,6 +2005,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2134,6 +2136,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2266,6 +2269,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2376,6 +2380,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer)