Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"cdr.dev/slog"
"github.com/coder/aibridge/mcp"
"go.opentelemetry.io/otel/trace"

"github.com/hashicorp/go-multierror"
)
Expand Down Expand Up @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{}
// A [Recorder] is also required to record prompt, tool, and token use.
//
// mcpProxy will be closed when the [RequestBridge] is closed.
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) {
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) {
mux := http.NewServeMux()

for _, provider := range providers {
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
for _, path := range provider.BridgedRoutes() {
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics))
mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer))
}

// Any requests which passthrough to this will be reverse-proxied to the upstream.
//
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
// configured, so we should just reverse-proxy known-safe routes.
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics)
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer)
for _, path := range provider.PassthroughRoutes() {
prefix := fmt.Sprintf("/%s", provider.Name())
route := fmt.Sprintf("%s%s", prefix, path)
Expand Down
115 changes: 62 additions & 53 deletions bridge_integration_test.go

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ require (
github.com/openai/openai-go/v2 v2.7.0
)

// Tracing-related libs.
require (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
require (
// Tracing-related libs.
require (

Also, thanks for getting rid of go-cmp 👍

go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
)

require (
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
Expand All @@ -46,6 +53,8 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
Expand All @@ -61,14 +70,13 @@ require (
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/term v0.34.0 // indirect
Expand Down
23 changes: 13 additions & 10 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
Expand Down Expand Up @@ -130,14 +131,16 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw=
go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I=
go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ=
go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M=
go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE=
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
Expand Down
16 changes: 16 additions & 0 deletions intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/tracing"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog"
)
Expand All @@ -27,6 +30,7 @@ type AnthropicMessagesInterceptionBase struct {
cfg AnthropicConfig
bedrockCfg *AWSBedrockConfig

tracer trace.Tracer
logger slog.Logger

recorder Recorder
Expand Down Expand Up @@ -59,6 +63,18 @@ func (i *AnthropicMessagesInterceptionBase) Model() string {
return string(i.req.Model)
}

func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id),
attribute.String(tracing.Provider, ProviderAnthropic),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil),
}
}

func (i *AnthropicMessagesInterceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil {
return
Expand Down
32 changes: 25 additions & 7 deletions intercept_anthropic_messages_blocking.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aibridge

import (
"context"
"fmt"
"net/http"
"time"
Expand All @@ -10,8 +11,11 @@ import (
"github.com/google/uuid"
mcplib "github.com/mark3labs/mcp-go/mcp" // TODO: abstract this away so callers need no knowledge of underlying lib.
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/tracing"

"cdr.dev/slog"
)
Expand All @@ -22,29 +26,35 @@ type AnthropicMessagesBlockingInterception struct {
AnthropicMessagesInterceptionBase
}

func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception {
func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesBlockingInterception {
return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{
id: id,
req: req,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
}}
}

func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) {
s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) {
i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
}

func (i *AnthropicMessagesBlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, false)
}

func (s *AnthropicMessagesBlockingInterception) Streaming() bool {
return false
}

func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error {
func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
if i.req == nil {
return fmt.Errorf("developer error: req is nil")
}

ctx := r.Context()
ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...))
defer tracing.EndSpanErr(span, &outErr)

i.injectTools()

Expand Down Expand Up @@ -77,7 +87,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
var cumulativeUsage anthropic.Usage

for {
resp, err = svc.New(ctx, messages)
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
resp, err = i.newMessage(ctx, svc, messages)
if err != nil {
if isConnError(err) {
// Can't write a response, just error out.
Expand Down Expand Up @@ -166,7 +177,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
continue
}

res, err := tool.Call(ctx, tc.Input)
res, err := tool.Call(ctx, tc.Input, i.tracer)

_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
InterceptionID: i.ID().String(),
Expand Down Expand Up @@ -286,3 +297,10 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr

return nil
}

func (i *AnthropicMessagesBlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) {
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)

return svc.New(ctx, msgParams)
}
31 changes: 26 additions & 5 deletions intercept_anthropic_messages_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ import (
"time"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/tracing"
"github.com/google/uuid"
mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog"
)
Expand All @@ -26,12 +30,13 @@ type AnthropicMessagesStreamingInterception struct {
AnthropicMessagesInterceptionBase
}

func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception {
func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesStreamingInterception {
return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{
id: id,
req: req,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
}}
}

Expand All @@ -43,6 +48,10 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool {
return true
}

func (s *AnthropicMessagesStreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, true)
}

// ProcessRequest handles a request to /v1/messages.
// This API has a state-machine behind it, which is described in https://docs.claude.com/en/docs/build-with-claude/streaming#event-types.
//
Expand All @@ -62,13 +71,16 @@ func (s *AnthropicMessagesStreamingInterception) Streaming() bool {
// b) if the tool is injected, it will be invoked by the [mcp.ServerProxier] in the remote MCP server, and its
// results relayed to the SERVER. The response from the server will be handled synchronously, and this loop
// can continue until all injected tool invocations are completed and the response is relayed to the client.
func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error {
func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
if i.req == nil {
return fmt.Errorf("developer error: req is nil")
}

ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...))
defer tracing.EndSpanErr(span, &outErr)

// Allow us to interrupt watch via cancel.
ctx, cancel := context.WithCancel(r.Context())
ctx, cancel := context.WithCancel(ctx)
defer cancel()
r = r.WithContext(ctx) // Rewire context for SSE cancellation.

Expand Down Expand Up @@ -118,12 +130,13 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW
isFirst := true
newStream:
for {
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
if err := streamCtx.Err(); err != nil {
lastErr = fmt.Errorf("stream exit: %w", err)
break
}

stream := svc.NewStreaming(streamCtx, messages)
stream := i.newStream(streamCtx, svc, messages)

var message anthropic.Message
var lastToolName string
Expand Down Expand Up @@ -270,7 +283,7 @@ newStream:
continue
}

res, err := tool.Call(streamCtx, input)
res, err := tool.Call(streamCtx, input, i.tracer)

_ = i.recorder.RecordToolUsage(streamCtx, &ToolUsageRecord{
InterceptionID: i.ID().String(),
Expand Down Expand Up @@ -522,3 +535,11 @@ func (s *AnthropicMessagesStreamingInterception) encodeForStream(payload []byte,
buf.WriteString("\n\n")
return buf.Bytes()
}

// newStream traces svc.NewStreaming(streamCtx, messages)
func (s *AnthropicMessagesStreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()

return svc.NewStreaming(ctx, messages)
}
24 changes: 20 additions & 4 deletions intercept_openai_chat_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@ import (
"strings"

"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/tracing"
"github.com/google/uuid"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
"github.com/openai/openai-go/v2/shared"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog"
)

type OpenAIChatInterceptionBase struct {
id uuid.UUID
req *ChatCompletionNewParamsWrapper
id uuid.UUID
req *ChatCompletionNewParamsWrapper
baseURL string
key string

baseURL, key string
logger slog.Logger
logger slog.Logger
tracer trace.Tracer

recorder Recorder
mcpProxy mcp.ServerProxier
Expand All @@ -42,6 +47,17 @@ func (i *OpenAIChatInterceptionBase) Setup(logger slog.Logger, recorder Recorder
i.mcpProxy = mcpProxy
}

func (s *OpenAIChatInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, actorFromContext(r.Context()).id),
attribute.String(tracing.Provider, ProviderOpenAI),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
}

func (i *OpenAIChatInterceptionBase) Model() string {
if i.req == nil {
return "coder-aibridge-unknown"
Expand Down
Loading