Skip to content

Commit ce93565

Browse files
authored
test: start migrating dbauthz tests to mocked db (#19257)
This PR adds a framework to move to a mocked db. And therefore massively speed up these tests.
1 parent 155c7bb commit ce93565

File tree

6 files changed

+186
-17
lines changed

6 files changed

+186
-17
lines changed

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ import (
1111
"testing"
1212
"time"
1313

14+
"github.com/brianvoe/gofakeit/v7"
1415
"github.com/google/uuid"
1516
"github.com/sqlc-dev/pqtype"
1617
"github.com/stretchr/testify/require"
18+
"go.uber.org/mock/gomock"
1719
"golang.org/x/xerrors"
1820

1921
"cdr.dev/slog"
@@ -22,6 +24,7 @@ import (
2224
"github.com/coder/coder/v2/coderd/database/db2sdk"
2325
"github.com/coder/coder/v2/coderd/database/dbauthz"
2426
"github.com/coder/coder/v2/coderd/database/dbgen"
27+
"github.com/coder/coder/v2/coderd/database/dbmock"
2528
"github.com/coder/coder/v2/coderd/database/dbtestutil"
2629
"github.com/coder/coder/v2/coderd/database/dbtime"
2730
"github.com/coder/coder/v2/coderd/notifications"
@@ -204,14 +207,15 @@ func defaultIPAddress() pqtype.Inet {
204207
}
205208

206209
func (s *MethodTestSuite) TestAPIKey() {
207-
s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
208-
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
209-
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
210+
s.Run("DeleteAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
211+
key := testutil.Fake(s.T(), faker, database.APIKey{})
212+
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
213+
dbm.EXPECT().DeleteAPIKeyByID(gomock.Any(), key.ID).Return(nil).AnyTimes()
210214
check.Args(key.ID).Asserts(key, policy.ActionDelete).Returns()
211215
}))
212-
s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
213-
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
214-
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
216+
s.Run("GetAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
217+
key := testutil.Fake(s.T(), faker, database.APIKey{})
218+
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
215219
check.Args(key.ID).Asserts(key, policy.ActionRead).Returns(key)
216220
}))
217221
s.Run("GetAPIKeyByName", s.Subtest(func(db database.Store, check *expects) {
@@ -234,14 +238,12 @@ func (s *MethodTestSuite) TestAPIKey() {
234238
Asserts(a, policy.ActionRead, b, policy.ActionRead).
235239
Returns(slice.New(a, b))
236240
}))
237-
s.Run("GetAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) {
238-
u1 := dbgen.User(s.T(), db, database.User{})
239-
u2 := dbgen.User(s.T(), db, database.User{})
240-
241-
keyA, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
242-
keyB, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})
243-
_, _ = dbgen.APIKey(s.T(), db, database.APIKey{UserID: u2.ID, LoginType: database.LoginTypeToken})
241+
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
242+
u1 := testutil.Fake(s.T(), faker, database.User{})
243+
keyA := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
244+
keyB := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})
244245

246+
dbm.EXPECT().GetAPIKeysByUserID(gomock.Any(), gomock.Any()).Return(slice.New(keyA, keyB), nil).AnyTimes()
245247
check.Args(database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: u1.ID}).
246248
Asserts(keyA, policy.ActionRead, keyB, policy.ActionRead).
247249
Returns(slice.New(keyA, keyB))

coderd/database/dbauthz/setup_test.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/brianvoe/gofakeit/v7"
1314
"github.com/google/go-cmp/cmp"
1415
"github.com/google/go-cmp/cmp/cmpopts"
1516
"github.com/google/uuid"
@@ -20,14 +21,14 @@ import (
2021
"golang.org/x/xerrors"
2122

2223
"cdr.dev/slog"
23-
"github.com/coder/coder/v2/coderd/rbac/policy"
2424

2525
"github.com/coder/coder/v2/coderd/coderdtest"
2626
"github.com/coder/coder/v2/coderd/database"
2727
"github.com/coder/coder/v2/coderd/database/dbauthz"
2828
"github.com/coder/coder/v2/coderd/database/dbmock"
2929
"github.com/coder/coder/v2/coderd/database/dbtestutil"
3030
"github.com/coder/coder/v2/coderd/rbac"
31+
"github.com/coder/coder/v2/coderd/rbac/policy"
3132
"github.com/coder/coder/v2/coderd/rbac/regosql"
3233
"github.com/coder/coder/v2/coderd/util/slice"
3334
)
@@ -105,19 +106,44 @@ func (s *MethodTestSuite) TearDownSuite() {
105106

106107
var testActorID = uuid.New()
107108

108-
// Subtest is a helper function that returns a function that can be passed to
109+
// Mocked runs a subtest with a mocked database. Removing the overhead of a real
110+
// postgres database resulting in much faster tests.
111+
func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *gofakeit.Faker, check *expects)) func() {
112+
t := s.T()
113+
mDB := dbmock.NewMockStore(gomock.NewController(t))
114+
mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes()
115+
116+
// Use a constant seed to prevent flakes from random data generation.
117+
faker := gofakeit.New(0)
118+
119+
// The usual Subtest assumes the test setup will use a real database to populate
120+
// with data. In this mocked case, we want to pass the underlying mocked database
121+
// to the test case instead.
122+
return s.SubtestWithDB(mDB, func(_ database.Store, check *expects) {
123+
testCaseF(mDB, faker, check)
124+
})
125+
}
126+
127+
// Subtest starts up a real postgres database for each test case.
128+
// Deprecated: Use 'Mocked' instead for much faster tests.
129+
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
130+
t := s.T()
131+
db, _ := dbtestutil.NewDB(t)
132+
return s.SubtestWithDB(db, testCaseF)
133+
}
134+
135+
// SubtestWithDB is a helper function that returns a function that can be passed to
109136
// s.Run(). This function will run the test case for the method that is being
110137
// tested. The check parameter is used to assert the results of the method.
111138
// If the caller does not use the `check` parameter, the test will fail.
112-
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
139+
func (s *MethodTestSuite) SubtestWithDB(db database.Store, testCaseF func(db database.Store, check *expects)) func() {
113140
return func() {
114141
t := s.T()
115142
testName := s.T().Name()
116143
names := strings.Split(testName, "/")
117144
methodName := names[len(names)-1]
118145
s.methodAccounting[methodName]++
119146

120-
db, _ := dbtestutil.NewDB(t)
121147
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
122148
rec := &coderdtest.RecordingAuthorizer{
123149
Wrapped: fakeAuthorizer,

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ require (
477477
)
478478

479479
require (
480+
github.com/brianvoe/gofakeit/v7 v7.3.0
480481
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
481482
github.com/coder/aisdk-go v0.0.9
482483
github.com/coder/preview v1.0.3

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
830830
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
831831
github.com/bramvdbogaerde/go-scp v1.5.0 h1:a9BinAjTfQh273eh7vd3qUgmBC+bx+3TRDtkZWmIpzM=
832832
github.com/bramvdbogaerde/go-scp v1.5.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ=
833+
github.com/brianvoe/gofakeit/v7 v7.3.0 h1:TWStf7/lLpAjKw+bqwzeORo9jvrxToWEwp9b1J2vApQ=
834+
github.com/brianvoe/gofakeit/v7 v7.3.0/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA=
833835
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
834836
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
835837
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA=

testutil/faker.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package testutil
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"github.com/brianvoe/gofakeit/v7"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// Fake will populate any zero fields in the provided struct with fake data.
12+
// Non-zero fields will remain unchanged.
13+
// Usage:
14+
//
15+
// key := Fake(t, faker, database.APIKey{
16+
// TokenName: "keep-my-name",
17+
// })
18+
func Fake[T any](t *testing.T, faker *gofakeit.Faker, seed T) T {
19+
t.Helper()
20+
21+
var tmp T
22+
err := faker.Struct(&tmp)
23+
require.NoError(t, err, "failed to generate fake data for type %T", tmp)
24+
25+
mergeZero(&seed, tmp)
26+
return seed
27+
}
28+
29+
// mergeZero merges the fields of src into dst, but only if the field in dst is
30+
// currently the zero value.
31+
// Make sure `dst` is a pointer to a struct, otherwise the fields are not assignable.
32+
func mergeZero(dst any, src any) {
33+
srcv := reflect.ValueOf(src)
34+
if srcv.Kind() == reflect.Ptr {
35+
srcv = srcv.Elem()
36+
}
37+
remain := [][2]reflect.Value{
38+
{reflect.ValueOf(dst).Elem(), srcv},
39+
}
40+
41+
// Traverse the struct fields and set them only if they are currently zero.
42+
// This is a breadth-first traversal of the struct fields. Struct definitions
43+
// Should not be that deep, so we should not hit any stack overflow issues.
44+
for {
45+
if len(remain) == 0 {
46+
return
47+
}
48+
dv, sv := remain[0][0], remain[0][1]
49+
remain = remain[1:] //
50+
for i := 0; i < dv.NumField(); i++ {
51+
df := dv.Field(i)
52+
sf := sv.Field(i)
53+
if !df.CanSet() {
54+
continue
55+
}
56+
if df.IsZero() { // only write if currently zero
57+
df.Set(sf)
58+
continue
59+
}
60+
61+
if dv.Field(i).Kind() == reflect.Struct {
62+
// If the field is a struct, we need to traverse it as well.
63+
remain = append(remain, [2]reflect.Value{df, sf})
64+
}
65+
}
66+
}
67+
}

testutil/faker_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package testutil_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/brianvoe/gofakeit/v7"
7+
"github.com/google/uuid"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
"github.com/coder/coder/v2/testutil"
12+
)
13+
14+
type simpleStruct struct {
15+
ID uuid.UUID
16+
Name string
17+
Description string
18+
Age int `fake:"{number:18,60}"`
19+
}
20+
21+
type nestedStruct struct {
22+
Person simpleStruct
23+
Address string
24+
}
25+
26+
func TestFake(t *testing.T) {
27+
t.Parallel()
28+
29+
t.Run("Simple", func(t *testing.T) {
30+
t.Parallel()
31+
32+
faker := gofakeit.New(0)
33+
person := testutil.Fake(t, faker, simpleStruct{
34+
Name: "alice",
35+
})
36+
require.Equal(t, "alice", person.Name)
37+
require.NotEqual(t, uuid.Nil, person.ID)
38+
require.NotEmpty(t, person.Description)
39+
require.Greater(t, person.Age, 17, "Age should be greater than 17")
40+
require.Less(t, person.Age, 61, "Age should be less than 61")
41+
})
42+
43+
t.Run("Nested", func(t *testing.T) {
44+
t.Parallel()
45+
46+
faker := gofakeit.New(0)
47+
person := testutil.Fake(t, faker, nestedStruct{
48+
Person: simpleStruct{
49+
Name: "alice",
50+
},
51+
})
52+
require.Equal(t, "alice", person.Person.Name)
53+
require.NotEqual(t, uuid.Nil, person.Person.ID)
54+
require.NotEmpty(t, person.Person.Description)
55+
require.Greater(t, person.Person.Age, 17, "Age should be greater than 17")
56+
require.NotEmpty(t, person.Address)
57+
})
58+
59+
t.Run("DatabaseType", func(t *testing.T) {
60+
t.Parallel()
61+
62+
faker := gofakeit.New(0)
63+
id := uuid.New()
64+
key := testutil.Fake(t, faker, database.APIKey{
65+
UserID: id,
66+
TokenName: "keep-my-name",
67+
})
68+
require.Equal(t, id, key.UserID)
69+
require.NotEmpty(t, key.TokenName)
70+
})
71+
}

0 commit comments

Comments
 (0)