From 042a5e2e960762a32125f4d2d9d7ba3f18fb71c4 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 30 Jul 2025 14:34:26 +0000 Subject: [PATCH] chore: move usage types to new package --- coderd/database/modelmethods.go | 9 -- .../provisionerdserver/provisionerdserver.go | 15 +-- .../provisionerdserver_test.go | 7 +- coderd/usage/collector.go | 5 +- coderd/usage/events.go | 47 ------- coderd/usage/usagetypes/events.go | 119 ++++++++++++++++++ coderd/usage/usagetypes/events_test.go | 59 +++++++++ coderd/usage/usagetypes/tallyman.go | 70 +++++++++++ coderd/usage/usagetypes/tallyman_test.go | 85 +++++++++++++ enterprise/coderd/usage/collector.go | 10 +- enterprise/coderd/usage/collector_test.go | 22 ++-- enterprise/coderd/usage/publisher.go | 78 ++++-------- enterprise/coderd/usage/publisher_test.go | 115 +++++++++-------- 13 files changed, 453 insertions(+), 188 deletions(-) delete mode 100644 coderd/usage/events.go create mode 100644 coderd/usage/usagetypes/events.go create mode 100644 coderd/usage/usagetypes/events_test.go create mode 100644 coderd/usage/usagetypes/tallyman.go create mode 100644 coderd/usage/usagetypes/tallyman_test.go diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index b326890d0f184..b49fa113d4b12 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "sort" "strconv" - "strings" "time" "github.com/google/uuid" @@ -629,11 +628,3 @@ func (m WorkspaceAgentVolumeResourceMonitor) Debounce( return m.DebouncedUntil, false } - -func (e UsageEventType) IsDiscrete() bool { - return e.Valid() && strings.HasPrefix(string(e), "dc_") -} - -func (e UsageEventType) IsHeartbeat() bool { - return e.Valid() && strings.HasPrefix(string(e), "hb_") -} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index ef34dd1719be7..d63d47cf4f634 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -28,14 +28,6 @@ import ( protobuf "google.golang.org/protobuf/proto" "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/usage" - "github.com/coder/coder/v2/coderd/util/slice" - - "github.com/coder/coder/v2/codersdk/drpcsdk" - - "github.com/coder/quartz" - "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" @@ -49,13 +41,18 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" + "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/provisioner" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" + "github.com/coder/quartz" ) const ( @@ -1903,7 +1900,7 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro // Collect usage event for managed agents. usageCollector := s.UsageCollector.Load() if usageCollector != nil { - event := usage.DCManagedAgentsV1{ + event := usagetypes.DCManagedAgentsV1{ Count: 1, } err = (*usageCollector).CollectDiscreteUsageEvent(ctx, db, event) diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 6ed6072518923..eaca618156b61 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -45,6 +45,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -2678,7 +2679,7 @@ func TestCompleteJob(t *testing.T) { // Check that a usage event was collected. require.Len(t, fakeUsageCollector.collectedEvents, 1) - require.Equal(t, usage.DCManagedAgentsV1{ + require.Equal(t, usagetypes.DCManagedAgentsV1{ Count: 1, }, fakeUsageCollector.collectedEvents[0]) } else { @@ -3850,7 +3851,7 @@ func (s *fakeStream) cancel() { } type fakeUsageCollector struct { - collectedEvents []usage.Event + collectedEvents []usagetypes.Event } var _ usage.Collector = &fakeUsageCollector{} @@ -3863,7 +3864,7 @@ func newFakeUsageCollector() (*fakeUsageCollector, *atomic.Pointer[usage.Collect return fake, ptr } -func (f *fakeUsageCollector) CollectDiscreteUsageEvent(_ context.Context, _ database.Store, event usage.DiscreteEvent) error { +func (f *fakeUsageCollector) CollectDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error { f.collectedEvents = append(f.collectedEvents, event) return nil } diff --git a/coderd/usage/collector.go b/coderd/usage/collector.go index 1a2e16ea43f01..59a3afc214e06 100644 --- a/coderd/usage/collector.go +++ b/coderd/usage/collector.go @@ -4,13 +4,14 @@ import ( "context" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/usage/usagetypes" ) // Collector is a sink for usage events generated by the product. type Collector interface { // CollectDiscreteUsageEvent writes a discrete usage event to the database // with the given database or transaction. - CollectDiscreteUsageEvent(ctx context.Context, db database.Store, event DiscreteEvent) error + CollectDiscreteUsageEvent(ctx context.Context, db database.Store, event usagetypes.DiscreteEvent) error } // AGPLCollector is a no-op implementation of Collector. @@ -24,6 +25,6 @@ func NewAGPLCollector() Collector { // CollectDiscreteUsageEvent is a no-op implementation of // CollectDiscreteUsageEvent. -func (AGPLCollector) CollectDiscreteUsageEvent(_ context.Context, _ database.Store, _ DiscreteEvent) error { +func (AGPLCollector) CollectDiscreteUsageEvent(_ context.Context, _ database.Store, _ usagetypes.DiscreteEvent) error { return nil } diff --git a/coderd/usage/events.go b/coderd/usage/events.go deleted file mode 100644 index 705435a96a244..0000000000000 --- a/coderd/usage/events.go +++ /dev/null @@ -1,47 +0,0 @@ -package usage - -import ( - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/database" -) - -// Event is a usage event that can be collected by the usage collector. -// -// Note that the following event types should not be updated once they are -// merged into the product. Please consult Dean before making any changes. -type Event interface { - usageEvent() // to prevent external types from implementing this interface - EventType() database.UsageEventType - Valid() error -} - -// DiscreteEvent is a usage event that is collected as a discrete event. -type DiscreteEvent interface { - Event - discreteUsageEvent() // marker method, also prevents external types from implementing this interface -} - -// DCManagedAgentsV1 is a discrete usage event for the number of managed agents. -// This event is sent in the following situations: -// - Once on first startup after usage tracking is added to the product with -// the count of all existing managed agents (count=N) -// - A new managed agent is created (count=1) -type DCManagedAgentsV1 struct { - Count uint64 `json:"count"` -} - -var _ DiscreteEvent = DCManagedAgentsV1{} - -func (DCManagedAgentsV1) usageEvent() {} -func (DCManagedAgentsV1) discreteUsageEvent() {} -func (DCManagedAgentsV1) EventType() database.UsageEventType { - return database.UsageEventTypeDcManagedAgentsV1 -} - -func (e DCManagedAgentsV1) Valid() error { - if e.Count == 0 { - return xerrors.New("count must be greater than 0") - } - return nil -} diff --git a/coderd/usage/usagetypes/events.go b/coderd/usage/usagetypes/events.go new file mode 100644 index 0000000000000..4d066cbccb44c --- /dev/null +++ b/coderd/usage/usagetypes/events.go @@ -0,0 +1,119 @@ +// Package usagetypes contains the types for usage events. These are kept in +// their own package to avoid importing any real code from coderd. +// +// Imports in this package should be limited to the standard library and the +// following packages ONLY: +// - github.com/google/uuid +// - golang.org/x/xerrors +// +// This package is imported by the Tallyman codebase. +package usagetypes + +// Please read the package documentation before adding imports. +import ( + "bytes" + "encoding/json" + "strings" + + "golang.org/x/xerrors" +) + +// UsageEventType is an enum of all usage event types. It mirrors the database +// type `usage_event_type`. +type UsageEventType string + +const ( + UsageEventTypeDCManagedAgentsV1 UsageEventType = "dc_managed_agents_v1" +) + +func (e UsageEventType) Valid() bool { + switch e { + case UsageEventTypeDCManagedAgentsV1: + return true + default: + return false + } +} + +func (e UsageEventType) IsDiscrete() bool { + return e.Valid() && strings.HasPrefix(string(e), "dc_") +} + +func (e UsageEventType) IsHeartbeat() bool { + return e.Valid() && strings.HasPrefix(string(e), "hb_") +} + +// ParseEvent parses the raw event data into the specified Go type. It fails if +// there is any unknown fields or extra data after the event. The returned event +// is validated. +func ParseEvent[T Event](data json.RawMessage) (T, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + + var event T + err := dec.Decode(&event) + if err != nil { + return event, xerrors.Errorf("unmarshal %T event: %w", event, err) + } + if dec.More() { + return event, xerrors.Errorf("extra data after %T event", event) + } + err = event.Valid() + if err != nil { + return event, xerrors.Errorf("invalid %T event: %w", event, err) + } + + return event, nil +} + +// ParseEventWithType parses the raw event data into the specified Go type. It +// fails if there is any unknown fields or extra data after the event. The +// returned event is validated. +func ParseEventWithType(eventType UsageEventType, data json.RawMessage) (Event, error) { + switch eventType { + case UsageEventTypeDCManagedAgentsV1: + return ParseEvent[DCManagedAgentsV1](data) + default: + return nil, xerrors.Errorf("unknown event type: %s", eventType) + } +} + +// Event is a usage event that can be collected by the usage collector. +// +// Note that the following event types should not be updated once they are +// merged into the product. Please consult Dean before making any changes. +type Event interface { + usageEvent() // to prevent external types from implementing this interface + EventType() UsageEventType + Valid() error +} + +// DiscreteEvent is a usage event that is collected as a discrete event. +type DiscreteEvent interface { + Event + discreteUsageEvent() // marker method, also prevents external types from implementing this interface +} + +// DCManagedAgentsV1 is a discrete usage event for the number of managed agents. +// This event is sent in the following situations: +// - Once on first startup after usage tracking is added to the product with +// the count of all existing managed agents (count=N) +// - A new managed agent is created (count=1) +type DCManagedAgentsV1 struct { + Count uint64 `json:"count"` +} + +var _ DiscreteEvent = DCManagedAgentsV1{} + +func (DCManagedAgentsV1) usageEvent() {} +func (DCManagedAgentsV1) discreteUsageEvent() {} +func (DCManagedAgentsV1) EventType() UsageEventType { + return UsageEventTypeDCManagedAgentsV1 +} + +func (e DCManagedAgentsV1) Valid() error { + if e.Count == 0 { + return xerrors.New("count must be greater than 0") + } + return nil +} diff --git a/coderd/usage/usagetypes/events_test.go b/coderd/usage/usagetypes/events_test.go new file mode 100644 index 0000000000000..9747f243d6c41 --- /dev/null +++ b/coderd/usage/usagetypes/events_test.go @@ -0,0 +1,59 @@ +package usagetypes_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +func TestParseEvent(t *testing.T) { + t.Parallel() + + t.Run("ExtraFields", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1, "extra": "field"}`)) + require.ErrorContains(t, err, "unmarshal usagetypes.DCManagedAgentsV1 event") + }) + + t.Run("ExtraData", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1}{"count": 2}`)) + require.ErrorContains(t, err, "extra data after usagetypes.DCManagedAgentsV1 event") + }) + + t.Run("DCManagedAgentsV1", func(t *testing.T) { + t.Parallel() + + event, err := usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": 1}`)) + require.NoError(t, err) + require.Equal(t, usagetypes.DCManagedAgentsV1{Count: 1}, event) + + _, err = usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{"count": "invalid"}`)) + require.ErrorContains(t, err, "unmarshal usagetypes.DCManagedAgentsV1 event") + + _, err = usagetypes.ParseEvent[usagetypes.DCManagedAgentsV1]([]byte(`{}`)) + require.ErrorContains(t, err, "invalid usagetypes.DCManagedAgentsV1 event: count must be greater than 0") + }) +} + +func TestParseEventWithType(t *testing.T) { + t.Parallel() + + t.Run("UnknownEvent", func(t *testing.T) { + t.Parallel() + _, err := usagetypes.ParseEventWithType(usagetypes.UsageEventType("fake"), []byte(`{}`)) + require.ErrorContains(t, err, "unknown event type: fake") + }) + + t.Run("DCManagedAgentsV1", func(t *testing.T) { + t.Parallel() + + eventType := usagetypes.UsageEventTypeDCManagedAgentsV1 + event, err := usagetypes.ParseEventWithType(eventType, []byte(`{"count": 1}`)) + require.NoError(t, err) + require.Equal(t, usagetypes.DCManagedAgentsV1{Count: 1}, event) + require.Equal(t, eventType, event.EventType()) + }) +} diff --git a/coderd/usage/usagetypes/tallyman.go b/coderd/usage/usagetypes/tallyman.go new file mode 100644 index 0000000000000..38358b7a6d518 --- /dev/null +++ b/coderd/usage/usagetypes/tallyman.go @@ -0,0 +1,70 @@ +package usagetypes + +// Please read the package documentation before adding imports. +import ( + "encoding/json" + "time" + + "golang.org/x/xerrors" +) + +const ( + TallymanCoderLicenseKeyHeader = "Coder-License-Key" + TallymanCoderDeploymentIDHeader = "Coder-Deployment-ID" +) + +// TallymanV1Response is a generic response with a message from the Tallyman +// API. It is typically returned when there is an error. +type TallymanV1Response struct { + Message string `json:"message"` +} + +// TallymanV1IngestRequest is a request to the Tallyman API to ingest usage +// events. +type TallymanV1IngestRequest struct { + Events []TallymanV1IngestEvent `json:"events"` +} + +// TallymanV1IngestEvent is an event to be ingested into the Tallyman API. +type TallymanV1IngestEvent struct { + ID string `json:"id"` + EventType UsageEventType `json:"event_type"` + EventData json.RawMessage `json:"event_data"` + CreatedAt time.Time `json:"created_at"` +} + +// Valid validates the TallymanV1IngestEvent. It does not validate the event +// body. +func (e TallymanV1IngestEvent) Valid() error { + if e.ID == "" { + return xerrors.New("id is required") + } + if !e.EventType.Valid() { + return xerrors.Errorf("event_type %q is invalid", e.EventType) + } + if e.CreatedAt.IsZero() { + return xerrors.New("created_at cannot be zero") + } + return nil +} + +// TallymanV1IngestResponse is a response from the Tallyman API to ingest usage +// events. +type TallymanV1IngestResponse struct { + AcceptedEvents []TallymanV1IngestAcceptedEvent `json:"accepted_events"` + RejectedEvents []TallymanV1IngestRejectedEvent `json:"rejected_events"` +} + +// TallymanV1IngestAcceptedEvent is an event that was accepted by the Tallyman +// API. +type TallymanV1IngestAcceptedEvent struct { + ID string `json:"id"` +} + +// TallymanV1IngestRejectedEvent is an event that was rejected by the Tallyman +// API. +type TallymanV1IngestRejectedEvent struct { + ID string `json:"id"` + Message string `json:"message"` + Permanent bool `json:"permanent"` +} diff --git a/coderd/usage/usagetypes/tallyman_test.go b/coderd/usage/usagetypes/tallyman_test.go new file mode 100644 index 0000000000000..f8f09446dff51 --- /dev/null +++ b/coderd/usage/usagetypes/tallyman_test.go @@ -0,0 +1,85 @@ +package usagetypes_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +func TestTallymanV1UsageEvent(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + event usagetypes.TallymanV1IngestEvent + errorMessage string + }{ + { + name: "OK", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + // EventData is not validated. + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: "", + }, + { + name: "NoID", + event: usagetypes.TallymanV1IngestEvent{ + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: "id is required", + }, + { + name: "NoEventType", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventType(""), + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: `event_type "" is invalid`, + }, + { + name: "UnknownEventType", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventType("unknown"), + EventData: json.RawMessage{}, + CreatedAt: time.Now(), + }, + errorMessage: `event_type "unknown" is invalid`, + }, + { + name: "NoCreatedAt", + event: usagetypes.TallymanV1IngestEvent{ + ID: "123", + EventType: usagetypes.UsageEventTypeDCManagedAgentsV1, + EventData: json.RawMessage{}, + CreatedAt: time.Time{}, + }, + errorMessage: "created_at cannot be zero", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := tc.event.Valid() + if tc.errorMessage == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.errorMessage) + } + }) + } +} diff --git a/enterprise/coderd/usage/collector.go b/enterprise/coderd/usage/collector.go index e25ee76710190..eb53b7d051c35 100644 --- a/enterprise/coderd/usage/collector.go +++ b/enterprise/coderd/usage/collector.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/quartz" ) @@ -42,7 +43,7 @@ func CollectorWithClock(clock quartz.Clock) CollectorOption { } // CollectDiscreteUsageEvent implements agplusage.Collector. -func (c *dbCollector) CollectDiscreteUsageEvent(ctx context.Context, db database.Store, event agplusage.DiscreteEvent) error { +func (c *dbCollector) CollectDiscreteUsageEvent(ctx context.Context, db database.Store, event usagetypes.DiscreteEvent) error { if !event.EventType().IsDiscrete() { return xerrors.Errorf("event type %q is not a discrete event", event.EventType()) } @@ -55,12 +56,17 @@ func (c *dbCollector) CollectDiscreteUsageEvent(ctx context.Context, db database return xerrors.Errorf("marshal event as JSON: %w", err) } + dbEventType := database.UsageEventType(event.EventType()) + if !dbEventType.Valid() { + return xerrors.Errorf("event type %q is invalid", event.EventType()) + } + // Duplicate events are ignored by the query, so we don't need to check the // error. return db.InsertUsageEvent(ctx, database.InsertUsageEventParams{ // Always generate a new UUID for discrete events. ID: uuid.New().String(), - EventType: event.EventType(), + EventType: database.UsageEventType(event.EventType()), EventData: jsonData, CreatedAt: dbtime.Time(c.clock.Now()), }) diff --git a/enterprise/coderd/usage/collector_test.go b/enterprise/coderd/usage/collector_test.go index 85c323c864db1..e703de8f46d0d 100644 --- a/enterprise/coderd/usage/collector_test.go +++ b/enterprise/coderd/usage/collector_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" - agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" @@ -33,37 +33,37 @@ func TestCollector(t *testing.T) { now := dbtime.Now() events := []struct { time time.Time - event agplusage.DiscreteEvent + event usagetypes.DiscreteEvent }{ { time: now, - event: agplusage.DCManagedAgentsV1{ + event: usagetypes.DCManagedAgentsV1{ Count: 1, }, }, { time: now.Add(1 * time.Minute), - event: agplusage.DCManagedAgentsV1{ + event: usagetypes.DCManagedAgentsV1{ Count: 2, }, }, } - for _, event := range events { - eventJSON := jsoninate(t, event.event) + for _, e := range events { + eventJSON := jsoninate(t, e.event) db.EXPECT().InsertUsageEvent(ctx, gomock.Any()).DoAndReturn( func(ctx interface{}, params database.InsertUsageEventParams) error { _, err := uuid.Parse(params.ID) assert.NoError(t, err) - assert.Equal(t, event.event.EventType(), params.EventType) + assert.Equal(t, e.event.EventType(), usagetypes.UsageEventType(params.EventType)) assert.JSONEq(t, eventJSON, string(params.EventData)) - assert.Equal(t, event.time, params.CreatedAt) + assert.Equal(t, e.time, params.CreatedAt) return nil }, ).Times(1) - clock.Set(event.time) - err := collector.CollectDiscreteUsageEvent(ctx, db, event.event) + clock.Set(e.time) + err := collector.CollectDiscreteUsageEvent(ctx, db, e.event) require.NoError(t, err) } }) @@ -77,7 +77,7 @@ func TestCollector(t *testing.T) { // We should get an error if the event is invalid. collector := usage.NewDBCollector() - err := collector.CollectDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := collector.CollectDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 0, // invalid }) assert.ErrorContains(t, err, `invalid "dc_managed_agents_v1" event: count must be greater than 0`) diff --git a/enterprise/coderd/usage/publisher.go b/enterprise/coderd/usage/publisher.go index 57b41cbc9cbb5..cad9682bfc0f0 100644 --- a/enterprise/coderd/usage/publisher.go +++ b/enterprise/coderd/usage/publisher.go @@ -16,14 +16,13 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/cryptorand" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/quartz" ) const ( - CoderLicenseJWTHeader = "Coder-License-JWT" - tallymanURL = "https://tallyman-ingress.coder.com" tallymanIngestURLV1 = tallymanURL + "/api/v1/ingest" @@ -209,16 +208,15 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U var ( eventIDs = make(map[string]struct{}) - tallymanReq = TallymanIngestRequestV1{ - DeploymentID: deploymentID, - Events: make([]TallymanIngestEventV1, 0, len(events)), + tallymanReq = usagetypes.TallymanV1IngestRequest{ + Events: make([]usagetypes.TallymanV1IngestEvent, 0, len(events)), } ) for _, event := range events { eventIDs[event.ID] = struct{}{} - tallymanReq.Events = append(tallymanReq.Events, TallymanIngestEventV1{ + tallymanReq.Events = append(tallymanReq.Events, usagetypes.TallymanV1IngestEvent{ ID: event.ID, - EventType: event.EventType, + EventType: usagetypes.UsageEventType(event.EventType), EventData: event.EventData, CreatedAt: event.CreatedAt, }) @@ -229,17 +227,17 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U return 0, xerrors.Errorf("duplicate event IDs found in events for publishing") } - resp, err := p.sendPublishRequest(ctx, licenseJwt, tallymanReq) + resp, err := p.sendPublishRequest(ctx, deploymentID, licenseJwt, tallymanReq) allFailed := err != nil if err != nil { p.log.Warn(ctx, "failed to send publish request to tallyman", slog.F("count", len(events)), slog.Error(err)) // Fake a response with all events temporarily rejected. - resp = TallymanIngestResponseV1{ - AcceptedEvents: []TallymanIngestAcceptedEventV1{}, - RejectedEvents: make([]TallymanIngestRejectedEventV1, len(events)), + resp = usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: make([]usagetypes.TallymanV1IngestRejectedEvent, len(events)), } for i, event := range events { - resp.RejectedEvents[i] = TallymanIngestRejectedEventV1{ + resp.RejectedEvents[i] = usagetypes.TallymanV1IngestRejectedEvent{ ID: event.ID, Message: fmt.Sprintf("failed to publish to tallyman: %v", err), Permanent: false, @@ -254,8 +252,8 @@ func (p *tallymanPublisher) publishOnce(ctx context.Context, deploymentID uuid.U } var ( - acceptedEvents = make(map[string]*TallymanIngestAcceptedEventV1) - rejectedEvents = make(map[string]*TallymanIngestRejectedEventV1) + acceptedEvents = make(map[string]*usagetypes.TallymanV1IngestAcceptedEvent) + rejectedEvents = make(map[string]*usagetypes.TallymanV1IngestRejectedEvent) ) for _, event := range resp.AcceptedEvents { acceptedEvents[event.ID] = &event @@ -350,37 +348,38 @@ func (p *tallymanPublisher) getBestLicenseJWT(ctx context.Context) (string, erro return bestLicense.Raw, nil } -func (p *tallymanPublisher) sendPublishRequest(ctx context.Context, licenseJwt string, req TallymanIngestRequestV1) (TallymanIngestResponseV1, error) { +func (p *tallymanPublisher) sendPublishRequest(ctx context.Context, deploymentID uuid.UUID, licenseJwt string, req usagetypes.TallymanV1IngestRequest) (usagetypes.TallymanV1IngestResponse, error) { body, err := json.Marshal(req) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } r, err := http.NewRequestWithContext(ctx, http.MethodPost, p.ingestURL, bytes.NewReader(body)) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } - r.Header.Set(CoderLicenseJWTHeader, licenseJwt) + r.Header.Set(usagetypes.TallymanCoderLicenseKeyHeader, licenseJwt) + r.Header.Set(usagetypes.TallymanCoderDeploymentIDHeader, deploymentID.String()) resp, err := p.httpClient.Do(r) if err != nil { - return TallymanIngestResponseV1{}, err + return usagetypes.TallymanV1IngestResponse{}, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - var errBody TallymanErrorV1 + var errBody usagetypes.TallymanV1Response if err := json.NewDecoder(resp.Body).Decode(&errBody); err != nil { - errBody = TallymanErrorV1{ + errBody = usagetypes.TallymanV1Response{ Message: fmt.Sprintf("could not decode error response body: %v", err), } } - return TallymanIngestResponseV1{}, xerrors.Errorf("unexpected status code %v, error: %s", resp.StatusCode, errBody.Message) + return usagetypes.TallymanV1IngestResponse{}, xerrors.Errorf("unexpected status code %v, error: %s", resp.StatusCode, errBody.Message) } - var respBody TallymanIngestResponseV1 + var respBody usagetypes.TallymanV1IngestResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return TallymanIngestResponseV1{}, xerrors.Errorf("decode response body: %w", err) + return usagetypes.TallymanV1IngestResponse{}, xerrors.Errorf("decode response body: %w", err) } return respBody, nil @@ -392,34 +391,3 @@ func (p *tallymanPublisher) Close() error { <-p.done return nil } - -type TallymanErrorV1 struct { - Message string `json:"message"` -} - -type TallymanIngestRequestV1 struct { - DeploymentID uuid.UUID `json:"deployment_id"` - Events []TallymanIngestEventV1 `json:"events"` -} - -type TallymanIngestEventV1 struct { - ID string `json:"id"` - EventType database.UsageEventType `json:"event_type"` - EventData json.RawMessage `json:"event_data"` - CreatedAt time.Time `json:"created_at"` -} - -type TallymanIngestResponseV1 struct { - AcceptedEvents []TallymanIngestAcceptedEventV1 `json:"accepted_events"` - RejectedEvents []TallymanIngestRejectedEventV1 `json:"rejected_events"` -} - -type TallymanIngestAcceptedEventV1 struct { - ID string `json:"id"` -} - -type TallymanIngestRejectedEventV1 struct { - ID string `json:"id"` - Message string `json:"message"` - Permanent bool `json:"permanent"` -} diff --git a/enterprise/coderd/usage/publisher_test.go b/enterprise/coderd/usage/publisher_test.go index 96c25a6eb4ee2..9f99ec8abd1ed 100644 --- a/enterprise/coderd/usage/publisher_test.go +++ b/enterprise/coderd/usage/publisher_test.go @@ -21,7 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" - agplusage "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/usage" "github.com/coder/coder/v2/testutil" @@ -47,16 +47,15 @@ func TestIntegration(t *testing.T) { var ( calls int64 - handler func(req usage.TallymanIngestRequestV1) any + handler func(req usagetypes.TallymanV1IngestRequest) any ) - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { callCount := atomic.AddInt64(&calls, 1) t.Logf("tallyman backend received call %d", callCount) - assert.Equal(t, deploymentID, req.DeploymentID) if handler == nil { t.Errorf("handler is nil") - return usage.TallymanIngestResponseV1{} + return usagetypes.TallymanV1IngestResponse{} } return handler(req) })) @@ -66,7 +65,7 @@ func TestIntegration(t *testing.T) { ) // Insert an old event that should never be published. clock.Set(now.Add(-31 * 24 * time.Hour)) - err := collector.CollectDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := collector.CollectDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 31, }) require.NoError(t, err) @@ -75,7 +74,7 @@ func TestIntegration(t *testing.T) { clock.Set(now.Add(1 * time.Second)) for i := 0; i < eventCount; i++ { clock.Advance(time.Second) - err := collector.CollectDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := collector.CollectDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: uint64(i + 1), // nolint:gosec // these numbers are tiny and will not overflow }) require.NoErrorf(t, err, "collecting event %d", i) @@ -110,33 +109,33 @@ func TestIntegration(t *testing.T) { // first event, temporarily reject the second, and permanently reject the // third. var temporarilyRejectedEventID string - handler = func(req usage.TallymanIngestRequestV1) any { + handler = func(req usagetypes.TallymanV1IngestRequest) any { // On the first call, accept the first event, temporarily reject the // second, and permanently reject the third. - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, 1) - rejectedEvents := make([]usage.TallymanIngestRejectedEventV1, 2) + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, 1) + rejectedEvents := make([]usagetypes.TallymanV1IngestRejectedEvent, 2) if assert.Len(t, req.Events, eventCount) { - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, }), string(req.Events[0].EventData), "event data did not match for event %d", 0) acceptedEvents[0].ID = req.Events[0].ID temporarilyRejectedEventID = req.Events[1].ID - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 2, }), string(req.Events[1].EventData), "event data did not match for event %d", 1) rejectedEvents[0].ID = req.Events[1].ID rejectedEvents[0].Message = "temporarily rejected" rejectedEvents[0].Permanent = false - assert.JSONEqf(t, jsoninate(t, agplusage.DCManagedAgentsV1{ + assert.JSONEqf(t, jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 3, }), string(req.Events[2].EventData), "event data did not match for event %d", 2) rejectedEvents[1].ID = req.Events[2].ID rejectedEvents[1].Message = "permanently rejected" rejectedEvents[1].Permanent = true } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, RejectedEvents: rejectedEvents, } @@ -155,16 +154,16 @@ func TestIntegration(t *testing.T) { // Set the handler for the next publish call. This call should only include // the temporarily rejected event from earlier. This time we'll accept it. - handler = func(req usage.TallymanIngestRequestV1) any { + handler = func(req usagetypes.TallymanV1IngestRequest) any { assert.Len(t, req.Events, 1) - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, len(req.Events)) + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, len(req.Events)) for i, event := range req.Events { assert.Equal(t, temporarilyRejectedEventID, event.ID) acceptedEvents[i].ID = event.ID } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } } @@ -204,11 +203,11 @@ func TestPublisherNoEligibleLicenses(t *testing.T) { db.EXPECT().GetDeploymentID(gomock.Any()).Return(deploymentID.String(), nil).Times(1) var calls int64 - ingestURL := fakeServer(t, tallymanHandler(t, "", func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), "", func(req usagetypes.TallymanV1IngestRequest) any { atomic.AddInt64(&calls, 1) - return usage.TallymanIngestResponseV1{ - AcceptedEvents: []usage.TallymanIngestAcceptedEventV1{}, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + return usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } })) @@ -273,11 +272,11 @@ func TestPublisherClaimExpiry(t *testing.T) { log := slogtest.Make(t, nil) db, _ := dbtestutil.NewDB(t) clock := quartz.NewMock(t) - _, licenseJWT := configureDeployment(ctx, t, db) + deploymentID, licenseJWT := configureDeployment(ctx, t, db) now := time.Now() var calls int64 - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { atomic.AddInt64(&calls, 1) return tallymanAcceptAllHandler(req) })) @@ -296,7 +295,7 @@ func TestPublisherClaimExpiry(t *testing.T) { // Create an event that was claimed 1h-18m ago. The ticker has a forced // delay of 17m in this test. clock.Set(now) - err := collector.CollectDiscreteUsageEvent(ctx, db, agplusage.DCManagedAgentsV1{ + err := collector.CollectDiscreteUsageEvent(ctx, db, usagetypes.DCManagedAgentsV1{ Count: 1, }) require.NoError(t, err) @@ -351,17 +350,17 @@ func TestPublisherMissingEvents(t *testing.T) { log := slogtest.Make(t, nil) ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) - _, licenseJWT := configureMockDeployment(t, db) + deploymentID, licenseJWT := configureMockDeployment(t, db) clock := quartz.NewMock(t) now := time.Now() clock.Set(now) var calls int64 - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { atomic.AddInt64(&calls, 1) - return usage.TallymanIngestResponseV1{ - AcceptedEvents: []usage.TallymanIngestAcceptedEventV1{}, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + return usagetypes.TallymanV1IngestResponse{ + AcceptedEvents: []usagetypes.TallymanV1IngestAcceptedEvent{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } })) @@ -376,7 +375,7 @@ func TestPublisherMissingEvents(t *testing.T) { { ID: uuid.New().String(), EventType: database.UsageEventTypeDcManagedAgentsV1, - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), CreatedAt: now, @@ -476,9 +475,8 @@ func TestPublisherLicenseSelection(t *testing.T) { }, nil) var calls int64 - ingestURL := fakeServer(t, tallymanHandler(t, expectedLicense, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), expectedLicense, func(req usagetypes.TallymanV1IngestRequest) any { atomic.AddInt64(&calls, 1) - assert.Equal(t, deploymentID, req.DeploymentID) return tallymanAcceptAllHandler(req) })) @@ -505,7 +503,7 @@ func TestPublisherLicenseSelection(t *testing.T) { { ID: uuid.New().String(), EventType: database.UsageEventTypeDcManagedAgentsV1, - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), }, @@ -540,12 +538,12 @@ func TestPublisherTallymanError(t *testing.T) { now := time.Now() clock.Set(now) - _, licenseJWT := configureMockDeployment(t, db) + deploymentID, licenseJWT := configureMockDeployment(t, db) const errorMessage = "tallyman error" var calls int64 - ingestURL := fakeServer(t, tallymanHandler(t, licenseJWT, func(req usage.TallymanIngestRequestV1) any { + ingestURL := fakeServer(t, tallymanHandler(t, deploymentID.String(), licenseJWT, func(req usagetypes.TallymanV1IngestRequest) any { atomic.AddInt64(&calls, 1) - return usage.TallymanErrorV1{ + return usagetypes.TallymanV1Response{ Message: errorMessage, } })) @@ -573,7 +571,7 @@ func TestPublisherTallymanError(t *testing.T) { { ID: uuid.New().String(), EventType: database.UsageEventTypeDcManagedAgentsV1, - EventData: []byte(jsoninate(t, agplusage.DCManagedAgentsV1{ + EventData: []byte(jsoninate(t, usagetypes.DCManagedAgentsV1{ Count: 1, })), }, @@ -653,44 +651,61 @@ func fakeServer(t *testing.T, handler http.Handler) string { return server.URL } -func tallymanHandler(t *testing.T, expectLicenseJWT string, handler func(req usage.TallymanIngestRequestV1) any) http.Handler { +func tallymanHandler(t *testing.T, expectDeploymentID string, expectLicenseJWT string, handler func(req usagetypes.TallymanV1IngestRequest) any) http.Handler { t.Helper() return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { t.Helper() - licenseJWT := r.Header.Get(usage.CoderLicenseJWTHeader) + licenseJWT := r.Header.Get(usagetypes.TallymanCoderLicenseKeyHeader) if expectLicenseJWT != "" && !assert.Equal(t, expectLicenseJWT, licenseJWT, "license JWT in request did not match") { rw.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(rw).Encode(usage.TallymanErrorV1{ + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ Message: "license JWT in request did not match", }) - require.NoError(t, err) return } - var req usage.TallymanIngestRequestV1 + deploymentID := r.Header.Get(usagetypes.TallymanCoderDeploymentIDHeader) + if expectDeploymentID != "" && !assert.Equal(t, expectDeploymentID, deploymentID, "deployment ID in request did not match") { + rw.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ + Message: "deployment ID in request did not match", + }) + return + } + + var req usagetypes.TallymanV1IngestRequest err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) + if !assert.NoError(t, err, "could not decode request body") { + rw.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(rw).Encode(usagetypes.TallymanV1Response{ + Message: "could not decode request body", + }) + return + } resp := handler(req) switch resp.(type) { - case usage.TallymanErrorV1: + case usagetypes.TallymanV1Response: rw.WriteHeader(http.StatusInternalServerError) default: rw.WriteHeader(http.StatusOK) } err = json.NewEncoder(rw).Encode(resp) - require.NoError(t, err) + if !assert.NoError(t, err, "could not encode response body") { + rw.WriteHeader(http.StatusInternalServerError) + return + } }) } -func tallymanAcceptAllHandler(req usage.TallymanIngestRequestV1) usage.TallymanIngestResponseV1 { - acceptedEvents := make([]usage.TallymanIngestAcceptedEventV1, len(req.Events)) +func tallymanAcceptAllHandler(req usagetypes.TallymanV1IngestRequest) usagetypes.TallymanV1IngestResponse { + acceptedEvents := make([]usagetypes.TallymanV1IngestAcceptedEvent, len(req.Events)) for i, event := range req.Events { acceptedEvents[i].ID = event.ID } - return usage.TallymanIngestResponseV1{ + return usagetypes.TallymanV1IngestResponse{ AcceptedEvents: acceptedEvents, - RejectedEvents: []usage.TallymanIngestRejectedEventV1{}, + RejectedEvents: []usagetypes.TallymanV1IngestRejectedEvent{}, } }