chore: refactoring; metrics do not need to be passed in Acquire
Signed-off-by: Danny Kopping <danny@coder.com>
dannykopping committed Nov 24, 2025
commit efd54e76fb03e5adec0c33ffd491784941dcf984
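At a glance, this commit moves the *aibridge.Metrics dependency out of Pooler.Acquire and into the pool's constructor, so callers create metrics once and the pool owns them, rather than threading them through on every request. Below is a minimal, runnable sketch of that constructor-injection pattern; the types are simplified stand-ins, not the real aibridged signatures shown in the diff.

```go
package main

import "fmt"

// Metrics is a stand-in for *aibridge.Metrics.
type Metrics struct{ prefix string }

// Pool owns its metrics for its whole lifetime...
type Pool struct{ metrics *Metrics }

// ...so the constructor receives them once (as NewCachedBridgePool now does)...
func NewPool(metrics *Metrics) *Pool {
	return &Pool{metrics: metrics}
}

// ...and Acquire no longer needs a metrics parameter (mirroring Pooler.Acquire).
func (p *Pool) Acquire(key string) string {
	return fmt.Sprintf("bridge for %q, instrumented via %q", key, p.metrics.prefix)
}

func main() {
	pool := NewPool(&Metrics{prefix: "coder_aibridged_"})
	fmt.Println(pool.Acquire("session-key"))
}
```

As the diffs below show, the real wiring follows the same shape: enterprise/cli/aibridged.go builds the metrics and hands them to NewCachedBridgePool, while the Pooler interface and Server.GetRequestHandler drop the per-call parameter.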
9 changes: 2 additions & 7 deletions enterprise/aibridged/aibridged.go
@@ -10,10 +10,7 @@ import (

"golang.org/x/xerrors"

"github.com/prometheus/client_golang/prometheus"

"cdr.dev/slog"
"github.com/coder/aibridge"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/retry"
)
@@ -34,7 +31,6 @@ type Server struct {

// A pool of [aibridge.RequestBridge] instances, which service incoming requests.
requestBridgePool Pooler
metrics *aibridge.Metrics

logger slog.Logger
wg sync.WaitGroup
@@ -52,7 +48,7 @@ type Server struct {
shutdownOnce sync.Once
}

func New(ctx context.Context, pool Pooler, rpcDialer Dialer, reg prometheus.Registerer, logger slog.Logger) (*Server, error) {
func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger) (*Server, error) {
if rpcDialer == nil {
return nil, xerrors.Errorf("nil rpcDialer given")
}
@@ -67,7 +63,6 @@ func New(ctx context.Context, pool Pooler, rpcDialer Dialer, reg prometheus.Regi
initConnectionCh: make(chan struct{}),

requestBridgePool: pool,
metrics: aibridge.NewMetrics(reg),
}

daemon.wg.Add(1)
@@ -148,7 +143,7 @@ func (s *Server) GetRequestHandler(ctx context.Context, req Request) (http.Handl
return nil, xerrors.New("nil requestBridgePool")
}

reqBridge, err := s.requestBridgePool.Acquire(ctx, req, s.Client, NewMCPProxyFactory(s.logger, s.Client), s.metrics)
reqBridge, err := s.requestBridgePool.Acquire(ctx, req, s.Client, NewMCPProxyFactory(s.logger, s.Client))
if err != nil {
return nil, xerrors.Errorf("acquire request bridge: %w", err)
}
4 changes: 2 additions & 2 deletions enterprise/aibridged/aibridged_integration_test.go
@@ -166,13 +166,13 @@ func TestIntegration(t *testing.T) {

logger := testutil.Logger(t)
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{BaseURL: mockOpenAI.URL})}
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger)
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, nil, logger)
require.NoError(t, err)

// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, nil, logger)
}, logger)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
9 changes: 4 additions & 5 deletions enterprise/aibridged/aibridged_test.go
@@ -41,8 +41,7 @@ func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock
pool,
func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
},
nil, logger)
}, logger)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
srv.Shutdown(context.Background())
@@ -123,7 +122,7 @@ func TestServeHTTP_FailureModes(t *testing.T) {
// Should pass authorization.
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
// But fail when acquiring a pool instance.
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("oops"))
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("oops"))
},
expectedErr: aibridged.ErrAcquireRequestHandler,
expectedStatus: http.StatusInternalServerError,
@@ -291,7 +290,7 @@ func TestRouting(t *testing.T) {
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{BaseURL: openaiSrv.URL}),
aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{BaseURL: antSrv.URL}, nil),
}
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger)
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, nil, logger)
require.NoError(t, err)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
@@ -310,7 +309,7 @@
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, nil, logger)
}, logger)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
9 changes: 4 additions & 5 deletions enterprise/aibridged/aibridgedmock/poolmock.go

Some generated files are not rendered by default.

14 changes: 9 additions & 5 deletions enterprise/aibridged/pool.go
@@ -23,7 +23,7 @@ const (
// Pooler describes a pool of [*aibridge.RequestBridge] instances from which instances can be retrieved.
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder, metrics *aibridge.Metrics) (http.Handler, error)
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
Shutdown(ctx context.Context) error
}

@@ -51,11 +51,13 @@ type CachedBridgePool struct {

singleflight *singleflight.Group[string, *aibridge.RequestBridge]

metrics *aibridge.Metrics

shutDownOnce sync.Once
shuttingDownCh chan struct{}
}

func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, logger slog.Logger) (*CachedBridgePool, error) {
func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, metrics *aibridge.Metrics, logger slog.Logger) (*CachedBridgePool, error) {
cache, err := ristretto.NewCache(&ristretto.Config[string, *aibridge.RequestBridge]{
NumCounters: options.MaxItems * 10, // Docs suggest setting this 10x number of keys.
MaxCost: options.MaxItems * cacheCost, // Up to n instances.
@@ -88,6 +90,8 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log

singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},

metrics: metrics,

shuttingDownCh: make(chan struct{}),
}, nil
}
@@ -96,7 +100,7 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
//
// Each returned [*aibridge.RequestBridge] is safe for concurrent use.
// Each [*aibridge.RequestBridge] is stateful because it has MCP clients which maintain sessions to the configured MCP server.
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpProxyFactory MCPProxyBuilder, metrics *aibridge.Metrics) (http.Handler, error) {
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpProxyFactory MCPProxyBuilder) (http.Handler, error) {
if err := ctx.Err(); err != nil {
return nil, xerrors.Errorf("acquire: %w", err)
}
@@ -154,7 +158,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
}

bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, metrics, p.logger)
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.metrics, p.logger)
if err != nil {
return nil, xerrors.Errorf("create new request bridge: %w", err)
}
@@ -167,7 +171,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
return instance, err
}

func (p *CachedBridgePool) Metrics() PoolMetrics {
func (p *CachedBridgePool) CacheMetrics() PoolMetrics {
if p.cache == nil {
return nil
}
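The doc comments in pool.go describe one stateful RequestBridge per key, reused on subsequent acquisitions. The following is a deliberately simplified, runnable sketch of that one-cached-instance-per-key behaviour; it uses a mutex-guarded map and a stand-in bridge type in place of the real ristretto cache, singleflight group, and aibridge types.

```go
package main

import (
	"fmt"
	"sync"
)

// bridge stands in for *aibridge.RequestBridge.
type bridge struct{ key string }

// keyedPool returns the same instance for the same key, loosely mirroring
// the "one instance per given key" contract described in the Pooler docs.
type keyedPool struct {
	mu      sync.Mutex
	bridges map[string]*bridge
}

func newKeyedPool() *keyedPool {
	return &keyedPool{bridges: make(map[string]*bridge)}
}

func (p *keyedPool) Acquire(key string) *bridge {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.bridges[key]; ok {
		return b // hit: reuse the cached, stateful instance
	}
	b := &bridge{key: key}
	p.bridges[key] = b
	return b // miss: build and cache a new instance
}

func main() {
	pool := newKeyedPool()
	a := pool.Acquire("user-1")
	b := pool.Acquire("user-1")
	c := pool.Acquire("user-2")
	fmt.Println(a == b, a == c) // true false
}
```

The test below exercises the same hit/miss accounting through the renamed CacheMetrics accessor.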
40 changes: 20 additions & 20 deletions enterprise/aibridged/pool_test.go
@@ -30,7 +30,7 @@ func TestPool(t *testing.T) {
mcpProxy := mcpmock.NewMockServerProxier(ctrl)

opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Second}
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger)
pool, err := aibridged.NewCachedBridgePool(opts, nil, nil, logger)
require.NoError(t, err)
t.Cleanup(func() { pool.Shutdown(context.Background()) })

@@ -51,23 +51,23 @@
SessionKey: "key",
InitiatorID: id,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy), nil)
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")

// ...and it will return it when acquired again.
instB, err := pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: id,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy), nil)
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")
require.Same(t, inst, instB)

metrics := pool.Metrics()
require.EqualValues(t, 1, metrics.KeysAdded())
require.EqualValues(t, 0, metrics.KeysEvicted())
require.EqualValues(t, 1, metrics.Hits())
require.EqualValues(t, 1, metrics.Misses())
cacheMetrics := pool.CacheMetrics()
require.EqualValues(t, 1, cacheMetrics.KeysAdded())
require.EqualValues(t, 0, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 1, cacheMetrics.Misses())

// This will get called again because a new instance will be created.
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
@@ -77,15 +77,15 @@ func TestPool(t *testing.T) {
SessionKey: "key",
InitiatorID: id2,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy), nil)
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")
require.NotSame(t, inst, inst2)

metrics = pool.Metrics()
require.EqualValues(t, 2, metrics.KeysAdded())
require.EqualValues(t, 1, metrics.KeysEvicted())
require.EqualValues(t, 1, metrics.Hits())
require.EqualValues(t, 2, metrics.Misses())
cacheMetrics = pool.CacheMetrics()
require.EqualValues(t, 2, cacheMetrics.KeysAdded())
require.EqualValues(t, 1, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 2, cacheMetrics.Misses())

// This will get called again because a new instance will be created.
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
@@ -95,15 +95,15 @@ func TestPool(t *testing.T) {
SessionKey: "key",
InitiatorID: id2,
APIKeyID: apiKeyID2.String(),
}, clientFn, newMockMCPFactory(mcpProxy), nil)
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance 2B")
require.NotSame(t, inst2, inst2B)

metrics = pool.Metrics()
require.EqualValues(t, 3, metrics.KeysAdded())
require.EqualValues(t, 2, metrics.KeysEvicted())
require.EqualValues(t, 1, metrics.Hits())
require.EqualValues(t, 3, metrics.Misses())
cacheMetrics = pool.CacheMetrics()
require.EqualValues(t, 3, cacheMetrics.KeysAdded())
require.EqualValues(t, 2, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 3, cacheMetrics.Misses())

// TODO: add test for expiry.
// This requires Go 1.25's [synctest](https://pkg.go.dev/testing/synctest) since the
7 changes: 5 additions & 2 deletions enterprise/cli/aibridged.go
@@ -33,16 +33,19 @@ func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
}, getBedrockConfig(coderAPI.DeploymentValues.AI.BridgeConfig.Bedrock)),
}

reg := prometheus.WrapRegistererWithPrefix("coder_aibridged_", coderAPI.PrometheusRegistry)
metrics := aibridge.NewMetrics(reg)

// Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user).
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool")) // TODO: configurable.
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, metrics, logger.Named("pool")) // TODO: configurable size.
if err != nil {
return nil, xerrors.Errorf("create request pool: %w", err)
}

// Create daemon.
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
}, prometheus.WrapRegistererWithPrefix("coder_aibridged_", coderAPI.PrometheusRegistry), logger)
}, logger)
if err != nil {
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
}
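With this change, newAIBridgeDaemon wraps the registry and constructs the metrics once, up front, instead of passing a registerer into aibridged.New. The sketch below isolates just the Prometheus prefixing that wiring relies on; the counter name is hypothetical and exists only to show the coder_aibridged_ prefix being applied.

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// Base registry, analogous to coderAPI.PrometheusRegistry in the diff.
	registry := prometheus.NewRegistry()

	// Collectors registered through the wrapper get the prefix, which is how
	// aibridge metrics end up namespaced as coder_aibridged_*.
	prefixed := prometheus.WrapRegistererWithPrefix("coder_aibridged_", registry)

	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: "example_requests_total", // hypothetical metric, for illustration only
		Help: "Shows the effect of WrapRegistererWithPrefix.",
	})
	prefixed.MustRegister(counter)
	counter.Inc()

	families, err := registry.Gather()
	if err != nil {
		panic(err)
	}
	for _, mf := range families {
		fmt.Println(mf.GetName()) // coder_aibridged_example_requests_total
	}
}
```

The resulting *aibridge.Metrics is then handed to NewCachedBridgePool, and aibridged.New no longer needs a prometheus.Registerer at all.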