Skip to content

Commit b199eb1

Browse files
authored
fix: allow stops and deletes after breaching AI limit (#21186)
Fixes a bug a customer encountered once they breached their limit. Adds a test.
1 parent 97bc7eb commit b199eb1

File tree

4 files changed

+108
-26
lines changed

4 files changed

+108
-26
lines changed

coderd/wsbuilder/wsbuilder.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ type Builder struct {
9393
}
9494

9595
type UsageChecker interface {
96-
CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error)
96+
CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (UsageCheckResponse, error)
9797
}
9898

9999
type UsageCheckResponse struct {
@@ -105,7 +105,7 @@ type NoopUsageChecker struct{}
105105

106106
var _ UsageChecker = NoopUsageChecker{}
107107

108-
func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) {
108+
func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ database.WorkspaceTransition) (UsageCheckResponse, error) {
109109
return UsageCheckResponse{
110110
Permitted: true,
111111
}, nil
@@ -1307,7 +1307,7 @@ func (b *Builder) checkUsage() error {
13071307
return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err}
13081308
}
13091309

1310-
resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion)
1310+
resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion, b.trans)
13111311
if err != nil {
13121312
return BuildError{http.StatusInternalServerError, "Failed to check build usage", err}
13131313
}

coderd/wsbuilder/wsbuilder_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
10491049

10501050
var calls int64
10511051
fakeUsageChecker := &fakeUsageChecker{
1052-
checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
1052+
checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) {
10531053
atomic.AddInt64(&calls, 1)
10541054
return wsbuilder.UsageCheckResponse{Permitted: true}, nil
10551055
},
@@ -1126,7 +1126,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
11261126

11271127
var calls int64
11281128
fakeUsageChecker := &fakeUsageChecker{
1129-
checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
1129+
checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) {
11301130
atomic.AddInt64(&calls, 1)
11311131
return c.response, c.responseErr
11321132
},
@@ -1577,11 +1577,11 @@ func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockSt
15771577
}
15781578

15791579
type fakeUsageChecker struct {
1580-
checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error)
1580+
checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error)
15811581
}
15821582

1583-
func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
1584-
return f.checkBuildUsageFunc(ctx, store, templateVersion)
1583+
func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) {
1584+
return f.checkBuildUsageFunc(ctx, store, templateVersion, transition)
15851585
}
15861586

15871587
func withNoTask(mTx *dbmock.MockStore) {

enterprise/coderd/coderd.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
971971

972972
var _ wsbuilder.UsageChecker = &API{}
973973

974-
func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
974+
func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) {
975975
// If the template version has an external agent, we need to check that the
976976
// license is entitled to this feature.
977977
if templateVersion.HasExternalAgent.Valid && templateVersion.HasExternalAgent.Bool {
@@ -984,16 +984,31 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ
984984
}
985985
}
986986

987-
// If the template version doesn't have an AI task, we don't need to check
988-
// usage.
987+
resp, err := api.checkAIBuildUsage(ctx, store, templateVersion, transition)
988+
if err != nil {
989+
return wsbuilder.UsageCheckResponse{}, err
990+
}
991+
if !resp.Permitted {
992+
return resp, nil
993+
}
994+
995+
return wsbuilder.UsageCheckResponse{Permitted: true}, nil
996+
}
997+
998+
// checkAIBuildUsage validates AI-related usage constraints. It is a no-op
999+
// unless the transition is "start" and the template version has an AI task.
1000+
func (api *API) checkAIBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) {
1001+
// Only check AI usage rules for start transitions.
1002+
if transition != database.WorkspaceTransitionStart {
1003+
return wsbuilder.UsageCheckResponse{Permitted: true}, nil
1004+
}
1005+
1006+
// If the template version doesn't have an AI task, we don't need to check usage.
9891007
if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
990-
return wsbuilder.UsageCheckResponse{
991-
Permitted: true,
992-
}, nil
1008+
return wsbuilder.UsageCheckResponse{Permitted: true}, nil
9931009
}
9941010

995-
// When unlicensed, we need to check that we haven't breached the managed agent
996-
// limit.
1011+
// When licensed, ensure we haven't breached the managed agent limit.
9971012
// Unlicensed deployments are allowed to use unlimited managed agents.
9981013
if api.Entitlements.HasLicense() {
9991014
managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit)
@@ -1004,8 +1019,9 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ
10041019
}, nil
10051020
}
10061021

1007-
// This check is intentionally not committed to the database. It's fine if
1008-
// it's not 100% accurate or allows for minor breaches due to build races.
1022+
// This check is intentionally not committed to the database. It's fine
1023+
// if it's not 100% accurate or allows for minor breaches due to build
1024+
// races.
10091025
// nolint:gocritic // Requires permission to read all usage events.
10101026
managedAgentCount, err := store.GetTotalUsageDCManagedAgentsV1(agpldbauthz.AsSystemRestricted(ctx), database.GetTotalUsageDCManagedAgentsV1Params{
10111027
StartDate: managedAgentLimit.UsagePeriod.Start,
@@ -1023,9 +1039,7 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ
10231039
}
10241040
}
10251041

1026-
return wsbuilder.UsageCheckResponse{
1027-
Permitted: true,
1028-
}, nil
1042+
return wsbuilder.UsageCheckResponse{Permitted: true}, nil
10291043
}
10301044

10311045
// getProxyDERPStartingRegionID returns the starting region ID that should be

enterprise/coderd/coderd_test.go

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package coderd_test
33
import (
44
"bytes"
55
"context"
6+
"database/sql"
67
"encoding/json"
78
"fmt"
89
"io"
@@ -21,6 +22,7 @@ import (
2122
"github.com/stretchr/testify/assert"
2223
"github.com/stretchr/testify/require"
2324
"go.uber.org/goleak"
25+
"go.uber.org/mock/gomock"
2426

2527
"cdr.dev/slog"
2628
"cdr.dev/slog/sloggers/slogtest"
@@ -39,13 +41,16 @@ import (
3941
"github.com/coder/retry"
4042
"github.com/coder/serpent"
4143

44+
agplcoderd "github.com/coder/coder/v2/coderd"
4245
agplaudit "github.com/coder/coder/v2/coderd/audit"
4346
"github.com/coder/coder/v2/coderd/coderdtest"
4447
"github.com/coder/coder/v2/coderd/database"
4548
"github.com/coder/coder/v2/coderd/database/dbauthz"
4649
"github.com/coder/coder/v2/coderd/database/dbfake"
50+
"github.com/coder/coder/v2/coderd/database/dbmock"
4751
"github.com/coder/coder/v2/coderd/database/dbtestutil"
4852
"github.com/coder/coder/v2/coderd/database/dbtime"
53+
"github.com/coder/coder/v2/coderd/entitlements"
4954
"github.com/coder/coder/v2/coderd/rbac"
5055
"github.com/coder/coder/v2/codersdk"
5156
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -635,18 +640,18 @@ func TestManagedAgentLimit(t *testing.T) {
635640
})
636641

637642
// Get entitlements to check that the license is a-ok.
638-
entitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine
643+
sdkEntitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine
639644
require.NoError(t, err)
640-
require.True(t, entitlements.HasLicense)
641-
agentLimit := entitlements.Features[codersdk.FeatureManagedAgentLimit]
645+
require.True(t, sdkEntitlements.HasLicense)
646+
agentLimit := sdkEntitlements.Features[codersdk.FeatureManagedAgentLimit]
642647
require.True(t, agentLimit.Enabled)
643648
require.NotNil(t, agentLimit.Limit)
644649
require.EqualValues(t, 1, *agentLimit.Limit)
645650
require.NotNil(t, agentLimit.SoftLimit)
646651
require.EqualValues(t, 1, *agentLimit.SoftLimit)
647-
require.Empty(t, entitlements.Errors)
652+
require.Empty(t, sdkEntitlements.Errors)
648653
// There should be a warning since we're really close to our agent limit.
649-
require.Equal(t, entitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.")
654+
require.Equal(t, sdkEntitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.")
650655

651656
// Create a fake provision response that claims there are agents in the
652657
// template and every built workspace.
@@ -723,6 +728,69 @@ func TestManagedAgentLimit(t *testing.T) {
723728
coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID)
724729
}
725730

731+
func TestCheckBuildUsage_SkipsAIForNonStartTransitions(t *testing.T) {
732+
t.Parallel()
733+
ctrl := gomock.NewController(t)
734+
defer ctrl.Finish()
735+
736+
// Prepare entitlements with a managed agent limit to enforce.
737+
entSet := entitlements.New()
738+
entSet.Modify(func(e *codersdk.Entitlements) {
739+
e.HasLicense = true
740+
limit := int64(1)
741+
issuedAt := time.Now().Add(-2 * time.Hour)
742+
start := time.Now().Add(-time.Hour)
743+
end := time.Now().Add(time.Hour)
744+
e.Features[codersdk.FeatureManagedAgentLimit] = codersdk.Feature{
745+
Enabled: true,
746+
Limit: &limit,
747+
UsagePeriod: &codersdk.UsagePeriod{IssuedAt: issuedAt, Start: start, End: end},
748+
}
749+
})
750+
751+
// Enterprise API instance with entitlements injected.
752+
agpl := &agplcoderd.API{
753+
Options: &agplcoderd.Options{
754+
Entitlements: entSet,
755+
},
756+
}
757+
eapi := &coderd.API{
758+
AGPL: agpl,
759+
Options: &coderd.Options{Options: agpl.Options},
760+
}
761+
762+
// Template version that has an AI task.
763+
tv := &database.TemplateVersion{
764+
HasAITask: sql.NullBool{Valid: true, Bool: true},
765+
HasExternalAgent: sql.NullBool{Valid: true, Bool: false},
766+
}
767+
768+
// Mock DB: expect exactly one count call for the "start" transition.
769+
mDB := dbmock.NewMockStore(ctrl)
770+
mDB.EXPECT().
771+
GetTotalUsageDCManagedAgentsV1(gomock.Any(), gomock.Any()).
772+
Times(1).
773+
Return(int64(1), nil) // equal to limit -> should breach
774+
775+
ctx := context.Background()
776+
777+
// Start transition: should be not permitted due to limit breach.
778+
startResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStart)
779+
require.NoError(t, err)
780+
require.False(t, startResp.Permitted)
781+
require.Contains(t, startResp.Message, "breached the managed agent limit")
782+
783+
// Stop transition: should be permitted and must not trigger additional DB calls.
784+
stopResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStop)
785+
require.NoError(t, err)
786+
require.True(t, stopResp.Permitted)
787+
788+
// Delete transition: should be permitted and must not trigger additional DB calls.
789+
deleteResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionDelete)
790+
require.NoError(t, err)
791+
require.True(t, deleteResp.Permitted)
792+
}
793+
726794
// testDBAuthzRole returns a context with a subject that has a role
727795
// with permissions required for test setup.
728796
func testDBAuthzRole(ctx context.Context) context.Context {

0 commit comments

Comments
 (0)