From 9d3e2d9c112bd65485cef41b85f253cb23504c87 Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Fri, 7 Mar 2025 22:34:20 +0000 Subject: [PATCH 1/7] chore: add support for one-way websockets to backend --- coderd/coderd.go | 14 +- coderd/httpapi/httpapi.go | 125 +++++++++++++++- coderd/httpapi/httpapi_test.go | 265 +++++++++++++++++++++++++++++++++ coderd/httpapi/websocket.go | 9 +- coderd/workspaceagents.go | 34 ++++- coderd/workspaces.go | 41 +++-- 6 files changed, 456 insertions(+), 32 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index ab8e99d29dea8..02d814be46ea9 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -823,7 +823,7 @@ func New(options *Options) *API { // we do not override subdomain app routes. r.Get("/latency-check", tracing.StatusWriterMiddleware(prometheusMW(LatencyCheck())).ServeHTTP) - r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) }) + r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("OK")) }) // Attach workspace apps routes. r.Group(func(r chi.Router) { @@ -838,7 +838,7 @@ func New(options *Options) *API { r.Route("/derp", func(r chi.Router) { r.Get("/", derpHandler.ServeHTTP) // This is used when UDP is blocked, and latency must be checked via HTTP(s). - r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) { + r.Get("/latency-check", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) }) @@ -895,7 +895,7 @@ func New(options *Options) *API { r.Route("/api/v2", func(r chi.Router) { api.APIHandler = r - r.NotFound(func(rw http.ResponseWriter, r *http.Request) { httpapi.RouteNotFound(rw) }) + r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) }) r.Use( // Specific routes can specify different limits, but every rate // limit must be configurable by the admin. @@ -1230,7 +1230,8 @@ func New(options *Options) *API { httpmw.ExtractWorkspaceParam(options.Database), ) r.Get("/", api.workspaceAgent) - r.Get("/watch-metadata", api.watchWorkspaceAgentMetadata) + r.Get("/watch-metadata", api.watchWorkspaceAgentMetadataSSE) + r.Get("/watch-metadata-ws", api.watchWorkspaceAgentMetadataWS) r.Get("/startup-logs", api.workspaceAgentLogsDeprecated) r.Get("/logs", api.workspaceAgentLogs) r.Get("/listening-ports", api.workspaceAgentListeningPorts) @@ -1262,7 +1263,8 @@ func New(options *Options) *API { r.Route("/ttl", func(r chi.Router) { r.Put("/", api.putWorkspaceTTL) }) - r.Get("/watch", api.watchWorkspace) + r.Get("/watch", api.watchWorkspaceSSE) + r.Get("/watch-ws", api.watchWorkspaceWS) r.Put("/extend", api.putExtendWorkspace) r.Post("/usage", api.postWorkspaceUsage) r.Put("/dormant", api.putWorkspaceDormant) @@ -1408,7 +1410,7 @@ func New(options *Options) *API { // global variable here. r.Get("/swagger/*", globalHTTPSwaggerHandler) } else { - swaggerDisabled := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + swaggerDisabled := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { httpapi.Write(context.Background(), rw, http.StatusNotFound, codersdk.Response{ Message: "Swagger documentation is disabled.", }) diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index d5895dcbf86f0..cdae4a01a636f 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -16,6 +16,9 @@ import ( "github.com/go-playground/validator/v10" "golang.org/x/xerrors" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" @@ -282,7 +285,25 @@ func WebsocketCloseSprintf(format string, vars ...any) string { return msg } -func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) { +type InitializeConnectionCallback func(rw http.ResponseWriter, r *http.Request) ( + sendEvent func(sse codersdk.ServerSentEvent) error, + done <-chan struct{}, + err error, +) + +// ServerSentEventSender establishes a Server-Sent Event connection and allows +// the consumer to send messages to the client. +// +// The function returned allows you to send a single message to the client, +// while the channel lets you listen for when the connection closes. +// +// As much as possible, this function should be avoided in favor of using the +// OneWayWebSocket function. See OneWayWebSocket for more context. +func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( + func(sse codersdk.ServerSentEvent) error, + <-chan struct{}, + error, +) { h := rw.Header() h.Set("Content-Type", "text/event-stream") h.Set("Cache-Control", "no-cache") @@ -294,7 +315,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f panic("http.ResponseWriter is not http.Flusher") } - closed = make(chan struct{}) + ctx := r.Context() + closed := make(chan struct{}) type sseEvent struct { payload []byte errC chan error @@ -333,21 +355,21 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f } }() - sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error { + sendEvent := func(newEvent codersdk.ServerSentEvent) error { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) - _, err := buf.WriteString(fmt.Sprintf("event: %s\n", sse.Type)) + _, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type)) if err != nil { return err } - if sse.Data != nil { + if newEvent.Data != nil { _, err = buf.WriteString("data: ") if err != nil { return err } - err = enc.Encode(sse.Data) + err = enc.Encode(newEvent.Data) if err != nil { return err } @@ -387,3 +409,94 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent f return sendEvent, closed, nil } + +// OneWayWebSocket establishes a new WebSocket connection that enforces one-way +// communication from the server to the client. +// +// The function returned allows you to send a single message to the client, +// while the channel lets you listen for when the connection closes. +// +// We must use an approach like this instead of Server-Sent Events for the +// browser, because on HTTP/1.1 connections, browsers are locked to no more than +// six HTTP connections for a domain total, across all tabs. If a user were to +// open a workspace in multiple tabs, the entire UI can start to lock up. +// WebSockets have no such limitation, no matter what HTTP protocol was used to +// establish the connection. +func OneWayWebSocket(rw http.ResponseWriter, r *http.Request) ( + func(event codersdk.ServerSentEvent) error, + <-chan struct{}, + error, +) { + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + socket, err := websocket.Accept(rw, r, nil) + if err != nil { + cancel() + return nil, nil, xerrors.Errorf("cannot establish connection: %w", err) + } + go Heartbeat(ctx, socket) + + type SocketError struct { + Code websocket.StatusCode + Reason string + } + eventC := make(chan codersdk.ServerSentEvent) + socketErrC := make(chan SocketError, 1) + closed := make(chan struct{}) + go func() { + defer cancel() + defer close(closed) + + for { + select { + case event := <-eventC: + writeCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + err := wsjson.Write(writeCtx, socket, event) + cancel() + if err == nil { + continue + } + _ = socket.Close(websocket.StatusInternalError, "Unable to send newest message") + case err := <-socketErrC: + _ = socket.Close(err.Code, err.Reason) + case <-ctx.Done(): + _ = socket.Close(websocket.StatusNormalClosure, "Connection closed") + } + return + } + }() + + // We have some tools in the UI code to help enforce one-way WebSocket + // connections, but there's still the possibility that the client could send + // a message when it's not supposed to. If that happens, the client likely + // forgot to use those tools, and communication probably can't be trusted. + // Better to just close the socket and force the UI to fix its mess + go func() { + _, _, err := socket.Read(ctx) + if errors.Is(err, context.Canceled) { + return + } + if err != nil { + socketErrC <- SocketError{ + Code: websocket.StatusInternalError, + Reason: "Unable to process invalid message from client", + } + return + } + socketErrC <- SocketError{ + Code: websocket.StatusProtocolError, + Reason: "Clients cannot send messages for one-way WebSockets", + } + }() + + sendEvent := func(event codersdk.ServerSentEvent) error { + select { + case eventC <- event: + case <-ctx.Done(): + return ctx.Err() + } + return nil + } + + return sendEvent, closed, nil +} diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index eb3f23e6ca346..655704b0ed1ff 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -1,14 +1,18 @@ package httpapi_test import ( + "bufio" "bytes" "context" "encoding/json" "fmt" + "io" + "net" "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" ) func TestInternalServerError(t *testing.T) { @@ -155,3 +160,263 @@ func TestWebsocketCloseMsg(t *testing.T) { assert.Equal(t, len(trunc), 123) }) } + +// Our WebSocket library accepts any arbitrary ResponseWriter at the type level, +// but it must also implement http.Hijack +type mockWsResponseWriter struct { + serverRecorder *httptest.ResponseRecorder + serverConn net.Conn + clientConn net.Conn + serverReadWriter *bufio.ReadWriter +} + +func (m mockWsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return m.serverConn, m.serverReadWriter, nil +} + +func (m mockWsResponseWriter) Flush() { + _ = m.serverReadWriter.Flush() +} + +func (m mockWsResponseWriter) Header() http.Header { + return m.serverRecorder.Header() +} + +func (m mockWsResponseWriter) Write(b []byte) (int, error) { + return m.serverReadWriter.Write(b) +} + +func (m mockWsResponseWriter) WriteHeader(code int) { + m.serverRecorder.WriteHeader(code) +} + +type mockWsWrite func(b []byte) (int, error) + +func (w mockWsWrite) Write(b []byte) (int, error) { + return w(b) +} + +func TestOneWayWebSocket(t *testing.T) { + t.Parallel() + + newBaseRequest := func(ctx context.Context) *http.Request { + url := "ws://www.fake-website.com/logs" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + h := req.Header + h.Add("Connection", "Upgrade") + h.Add("Upgrade", "websocket") + h.Add("Sec-WebSocket-Version", "13") + h.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Just need any string + + return req + } + + newWebsocketWriter := func() mockWsResponseWriter { + mockServer, mockClient := net.Pipe() + recorder := httptest.NewRecorder() + + var write mockWsWrite = func(b []byte) (int, error) { + serverCount, err := mockServer.Write(b) + if err != nil { + return serverCount, err + } + recorderCount, err := recorder.Write(b) + return min(serverCount, recorderCount), err + } + + return mockWsResponseWriter{ + serverConn: mockServer, + clientConn: mockClient, + serverRecorder: recorder, + serverReadWriter: bufio.NewReadWriter( + bufio.NewReader(mockServer), + bufio.NewWriter(write), + ), + } + } + + t.Run("Produces error if the socket connection could not be established", func(t *testing.T) { + t.Parallel() + + incorrectProtocols := []struct { + major int + minor int + proto string + }{ + {0, 9, "HTTP/0.9"}, + {1, 0, "HTTP/1.0"}, + } + for _, p := range incorrectProtocols { + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + req.ProtoMajor = p.major + req.ProtoMinor = p.minor + req.Proto = p.proto + + writer := newWebsocketWriter() + _, _, err := httpapi.OneWayWebSocket(writer, req) + require.ErrorContains(t, err, p.proto) + } + }) + + t.Run("Returned callback can publish new event to WebSocket connection", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newWebsocketWriter() + send, _, err := httpapi.OneWayWebSocket(writer, req) + require.NoError(t, err) + + serverPayload := codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Blah", + } + err = send(serverPayload) + require.NoError(t, err) + + // The client connection will receive a little bit of additional data on + // top of the main payload. Have to make sure check has tolerance for + // extra data being present + serverBytes, err := json.Marshal(serverPayload) + require.NoError(t, err) + clientBytes, err := io.ReadAll(writer.clientConn) + require.NoError(t, err) + require.True(t, bytes.Contains(clientBytes, serverBytes)) + }) + + t.Run("Signals to outside consumer when socket has been closed", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newWebsocketWriter() + _, done, err := httpapi.OneWayWebSocket(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) + }) + + t.Run("Socket will immediately close if client sends any message", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newWebsocketWriter() + _, done, err := httpapi.OneWayWebSocket(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + type JunkClientEvent struct { + Value string + } + b, err := json.Marshal(JunkClientEvent{"Hi :)"}) + require.NoError(t, err) + _, err = writer.clientConn.Write(b) + require.NoError(t, err) + require.True(t, <-successC) + }) + + t.Run("Renders the socket inert if the request context cancels", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newWebsocketWriter() + send, done, err := httpapi.OneWayWebSocket(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) + err = send(codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Didn't realize you were closed - sorry! I'll try coming back tomorrow.", + }) + require.Equal(t, err, ctx.Err()) + _, open := <-done + require.False(t, open) + _, err = writer.serverConn.Write([]byte{}) + require.Equal(t, err, io.ErrClosedPipe) + _, err = writer.clientConn.Read([]byte{}) + require.Equal(t, err, io.EOF) + }) + + t.Run("Sends a heartbeat to the socket on a fixed internal of time to keep connections alive", func(t *testing.T) { + t.Parallel() + + // Need add at least three heartbeats for something to be reliably + // counted as an interval, but also need some wiggle room + heartbeatCount := 3 + hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval + timeout := hbDuration + (5 * time.Second) + + ctx := testutil.Context(t, timeout) + req := newBaseRequest(ctx) + writer := newWebsocketWriter() + _, _, err := httpapi.OneWayWebSocket(writer, req) + require.NoError(t, err) + + type Result struct { + Err error + Success bool + } + resultC := make(chan Result) + go func() { + err := writer. + clientConn. + SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + resultC <- Result{err, false} + return + } + for range heartbeatCount { + pingBuffer := make([]byte, 1) + pingSize, err := writer.clientConn.Read(pingBuffer) + if err != nil || pingSize != 1 { + resultC <- Result{err, false} + return + } + } + resultC <- Result{nil, true} + }() + + result := <-resultC + require.NoError(t, result.Err) + require.True(t, result.Success) + }) +} diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index 20c780f6bffa0..3a71c9c9ae8b0 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -11,11 +11,13 @@ import ( "github.com/coder/websocket" ) +const HeartbeatInterval time.Duration = 15 * time.Second + // Heartbeat loops to ping a WebSocket to keep it alive. // Default idle connection timeouts are typically 60 seconds. // See: https://docs.aws.amazon.com/elasticloadbalancing/latest/application/application-load-balancers.html#connection-idle-timeout func Heartbeat(ctx context.Context, conn *websocket.Conn) { - ticker := time.NewTicker(15 * time.Second) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { select { @@ -33,8 +35,7 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) { // Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping // failure. func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) { - interval := 15 * time.Second - ticker := time.NewTicker(interval) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { @@ -43,7 +44,7 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn * return case <-ticker.C: } - err := pingWithTimeout(ctx, conn, interval) + err := pingWithTimeout(ctx, conn, HeartbeatInterval) if err != nil { // context.DeadlineExceeded is expected when the client disconnects without sending a close frame if !errors.Is(err, context.DeadlineExceeded) { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index ff16735af9aea..7234fda2cb75a 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1094,7 +1094,29 @@ func convertScripts(dbScripts []database.WorkspaceAgentScript) []codersdk.Worksp // @Param workspaceagent path string true "Workspace agent ID" format(uuid) // @Router /workspaceagents/{workspaceagent}/watch-metadata [get] // @x-apidocgen {"skip": true} -func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { +// @Deprecated Use /workspaceagents/{workspaceagent}/watch-metadata-ws instead +func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspaceAgentMetadata(rw, r, httpapi.ServerSentEventSender) +} + +// @Summary Watch for workspace agent metadata updates via WebSockets +// @ID watch-for-workspace-agent-metadata-updates-via-websockets +// @Security CoderSessionToken +// @Produce json +// @Tags Agents +// @Success 200 {object} codersdk.ServerSentEvent +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get] +// @x-apidocgen {"skip": true} +func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocket) +} + +func (api *API) watchWorkspaceAgentMetadata( + rw http.ResponseWriter, + r *http.Request, + connect httpapi.InitializeConnectionCallback, +) { // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(r.Context()) defer cancel() @@ -1159,7 +1181,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ //nolint:ineffassign // Release memory. initialMD = nil - sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r) + sendEvent, senderClosed, err := connect(rw, r) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error setting up server-sent events.", @@ -1170,14 +1192,14 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ // Prevent handler from returning until the sender is closed. defer func() { cancel() - <-sseSenderClosed + <-senderClosed }() // Synchronize cancellation from SSE -> context, this lets us simplify the // cancellation logic. go func() { select { case <-ctx.Done(): - case <-sseSenderClosed: + case <-senderClosed: cancel() } }() @@ -1189,7 +1211,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ log.Debug(ctx, "sending metadata", "num", len(values)) - _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeData, Data: convertWorkspaceAgentMetadata(values), }) @@ -1221,7 +1243,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ if err != nil { if !database.IsQueryCanceledError(err) { log.Error(ctx, "failed to get metadata", slog.Error(err)) - _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Failed to get metadata.", diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 7a64648033c79..a4a0027ae7475 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -1718,12 +1718,33 @@ func (api *API) resolveAutostart(rw http.ResponseWriter, r *http.Request) { // @Param workspace path string true "Workspace ID" format(uuid) // @Success 200 {object} codersdk.Response // @Router /workspaces/{workspace}/watch [get] -func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { +// @Deprecated Use /workspaces/{workspace}/watch-ws instead +func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspace(rw, r, httpapi.ServerSentEventSender) +} + +// @Summary Watch workspace by ID via WebSockets +// @ID watch-workspace-by-id-via-websockets +// @Security CoderSessionToken +// @Produce json +// @Tags Workspaces +// @Param workspace path string true "Workspace ID" format(uuid) +// @Success 200 {object} codersdk.ServerSentEvent +// @Router /workspaces/{workspace}/watch-ws [get] +func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) { + api.watchWorkspace(rw, r, httpapi.OneWayWebSocket) +} + +func (api *API) watchWorkspace( + rw http.ResponseWriter, + r *http.Request, + connect httpapi.InitializeConnectionCallback, +) { ctx := r.Context() workspace := httpmw.WorkspaceParam(r) apiKey := httpmw.APIKey(r) - sendEvent, senderClosed, err := httpapi.ServerSentEventSender(rw, r) + sendEvent, senderClosed, err := connect(rw, r) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error setting up server-sent events.", @@ -1739,7 +1760,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { sendUpdate := func(_ context.Context, _ []byte) { workspace, err := api.Database.GetWorkspaceByID(ctx, workspace.ID) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error fetching workspace.", @@ -1751,7 +1772,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { data, err := api.workspaceData(ctx, []database.Workspace{workspace}) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error fetching workspace data.", @@ -1761,7 +1782,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { return } if len(data.templates) == 0 { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Forbidden reading template of selected workspace.", @@ -1778,7 +1799,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { api.Options.AllowWorkspaceRenames, ) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error converting workspace.", @@ -1786,7 +1807,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { }, }) } - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeData, Data: w, }) @@ -1804,7 +1825,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { sendUpdate(ctx, nil) })) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error subscribing to workspace events.", @@ -1818,7 +1839,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { // This is required to show whether the workspace is up-to-date. cancelTemplateSubscribe, err := api.Pubsub.Subscribe(watchTemplateChannel(workspace.TemplateID), sendUpdate) if err != nil { - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, Data: codersdk.Response{ Message: "Internal error subscribing to template events.", @@ -1831,7 +1852,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { // An initial ping signals to the request that the server is now ready // and the client can begin servicing a channel with data. - _ = sendEvent(ctx, codersdk.ServerSentEvent{ + _ = sendEvent(codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypePing, }) // Send updated workspace info after connection is established. This avoids From e312932175eb0d3ec38b106843f8dc0851027afc Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Fri, 7 Mar 2025 22:40:09 +0000 Subject: [PATCH 2/7] chore: make gen --- coderd/apidoc/docs.go | 97 ++++++++++++++++++++++++++++++++ coderd/apidoc/swagger.json | 85 ++++++++++++++++++++++++++++ docs/reference/api/schemas.md | 32 +++++++++++ docs/reference/api/workspaces.md | 38 +++++++++++++ 4 files changed, 252 insertions(+) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 8f90cd5c205a2..97e909dfa6281 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -8293,6 +8293,7 @@ const docTemplate = `{ ], "summary": "Watch for workspace agent metadata updates", "operationId": "watch-for-workspace-agent-metadata-updates", + "deprecated": true, "parameters": [ { "type": "string", @@ -8313,6 +8314,44 @@ const docTemplate = `{ } } }, + "/workspaceagents/{workspaceagent}/watch-metadata-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Agents" + ], + "summary": "Watch for workspace agent metadata updates via WebSockets", + "operationId": "watch-for-workspace-agent-metadata-updates-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspacebuilds/{workspacebuild}": { "get": { "security": [ @@ -9724,6 +9763,7 @@ const docTemplate = `{ ], "summary": "Watch workspace by ID", "operationId": "watch-workspace-by-id", + "deprecated": true, "parameters": [ { "type": "string", @@ -9743,6 +9783,41 @@ const docTemplate = `{ } } } + }, + "/workspaces/{workspace}/watch-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Workspaces" + ], + "summary": "Watch workspace by ID via WebSockets", + "operationId": "watch-workspace-by-id-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + } + } } }, "definitions": { @@ -14165,6 +14240,28 @@ const docTemplate = `{ } } }, + "codersdk.ServerSentEvent": { + "type": "object", + "properties": { + "data": {}, + "type": { + "$ref": "#/definitions/codersdk.ServerSentEventType" + } + } + }, + "codersdk.ServerSentEventType": { + "type": "string", + "enum": [ + "ping", + "data", + "error" + ], + "x-enum-varnames": [ + "ServerSentEventTypePing", + "ServerSentEventTypeData", + "ServerSentEventTypeError" + ] + }, "codersdk.SessionCountDeploymentStats": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index fcfe56d3fc4aa..3de15ab60e2c3 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -7333,6 +7333,7 @@ "tags": ["Agents"], "summary": "Watch for workspace agent metadata updates", "operationId": "watch-for-workspace-agent-metadata-updates", + "deprecated": true, "parameters": [ { "type": "string", @@ -7353,6 +7354,40 @@ } } }, + "/workspaceagents/{workspaceagent}/watch-metadata-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": ["application/json"], + "tags": ["Agents"], + "summary": "Watch for workspace agent metadata updates via WebSockets", + "operationId": "watch-for-workspace-agent-metadata-updates-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspacebuilds/{workspacebuild}": { "get": { "security": [ @@ -8606,6 +8641,7 @@ "tags": ["Workspaces"], "summary": "Watch workspace by ID", "operationId": "watch-workspace-by-id", + "deprecated": true, "parameters": [ { "type": "string", @@ -8625,6 +8661,37 @@ } } } + }, + "/workspaces/{workspace}/watch-ws": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": ["application/json"], + "tags": ["Workspaces"], + "summary": "Watch workspace by ID via WebSockets", + "operationId": "watch-workspace-by-id-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace ID", + "name": "workspace", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ServerSentEvent" + } + } + } + } } }, "definitions": { @@ -12840,6 +12907,24 @@ } } }, + "codersdk.ServerSentEvent": { + "type": "object", + "properties": { + "data": {}, + "type": { + "$ref": "#/definitions/codersdk.ServerSentEventType" + } + } + }, + "codersdk.ServerSentEventType": { + "type": "string", + "enum": ["ping", "data", "error"], + "x-enum-varnames": [ + "ServerSentEventTypePing", + "ServerSentEventTypeData", + "ServerSentEventTypeError" + ] + }, "codersdk.SessionCountDeploymentStats": { "type": "object", "properties": { diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 9fa22af7356ae..5767d33e7e3cc 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -5526,6 +5526,38 @@ Git clone makes use of this by parsing the URL from: 'Username for "https://gith | `ssh_config_options` | object | false | | | | » `[any property]` | string | false | | | +## codersdk.ServerSentEvent + +```json +{ + "data": null, + "type": "ping" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|--------|--------------------------------------------------------------|----------|--------------|-------------| +| `data` | any | false | | | +| `type` | [codersdk.ServerSentEventType](#codersdkserversenteventtype) | false | | | + +## codersdk.ServerSentEventType + +```json +"ping" +``` + +### Properties + +#### Enumerated Values + +| Value | +|---------| +| `ping` | +| `data` | +| `error` | + ## codersdk.SessionCountDeploymentStats ```json diff --git a/docs/reference/api/workspaces.md b/docs/reference/api/workspaces.md index 7264b6dbb3939..18500158567ae 100644 --- a/docs/reference/api/workspaces.md +++ b/docs/reference/api/workspaces.md @@ -1979,3 +1979,41 @@ curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch \ | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.Response](schemas.md#codersdkresponse) | To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## Watch workspace by ID via WebSockets + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/watch-ws \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /workspaces/{workspace}/watch-ws` + +### Parameters + +| Name | In | Type | Required | Description | +|-------------|------|--------------|----------|--------------| +| `workspace` | path | string(uuid) | true | Workspace ID | + +### Example responses + +> 200 Response + +```json +{ + "data": null, + "type": "ping" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.ServerSentEvent](schemas.md#codersdkserversentevent) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). From 43b16766c19730e74cfd2c07b431f1542e70c3c0 Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Wed, 19 Mar 2025 19:54:13 +0000 Subject: [PATCH 3/7] fix: apply feedback --- coderd/httpapi/httpapi.go | 18 +++++++----------- coderd/httpapi/httpapi_test.go | 16 ++++++++-------- coderd/workspaceagents.go | 4 ++-- coderd/workspaces.go | 4 ++-- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index cdae4a01a636f..d43d8a74d131f 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -285,7 +285,7 @@ func WebsocketCloseSprintf(format string, vars ...any) string { return msg } -type InitializeConnectionCallback func(rw http.ResponseWriter, r *http.Request) ( +type EventSender func(rw http.ResponseWriter, r *http.Request) ( sendEvent func(sse codersdk.ServerSentEvent) error, done <-chan struct{}, err error, @@ -410,8 +410,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( return sendEvent, closed, nil } -// OneWayWebSocket establishes a new WebSocket connection that enforces one-way -// communication from the server to the client. +// WebSocketEventSender establishes a new WebSocket connection that enforces +// one-way communication from the server to the client. // // The function returned allows you to send a single message to the client, // while the channel lets you listen for when the connection closes. @@ -422,7 +422,7 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( // open a workspace in multiple tabs, the entire UI can start to lock up. // WebSockets have no such limitation, no matter what HTTP protocol was used to // establish the connection. -func OneWayWebSocket(rw http.ResponseWriter, r *http.Request) ( +func WebSocketEventSender(rw http.ResponseWriter, r *http.Request) ( func(event codersdk.ServerSentEvent) error, <-chan struct{}, error, @@ -436,12 +436,8 @@ func OneWayWebSocket(rw http.ResponseWriter, r *http.Request) ( } go Heartbeat(ctx, socket) - type SocketError struct { - Code websocket.StatusCode - Reason string - } eventC := make(chan codersdk.ServerSentEvent) - socketErrC := make(chan SocketError, 1) + socketErrC := make(chan websocket.CloseError, 1) closed := make(chan struct{}) go func() { defer cancel() @@ -477,13 +473,13 @@ func OneWayWebSocket(rw http.ResponseWriter, r *http.Request) ( return } if err != nil { - socketErrC <- SocketError{ + socketErrC <- websocket.CloseError{ Code: websocket.StatusInternalError, Reason: "Unable to process invalid message from client", } return } - socketErrC <- SocketError{ + socketErrC <- websocket.CloseError{ Code: websocket.StatusProtocolError, Reason: "Clients cannot send messages for one-way WebSockets", } diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index 655704b0ed1ff..9ab64a4a55f10 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -162,7 +162,7 @@ func TestWebsocketCloseMsg(t *testing.T) { } // Our WebSocket library accepts any arbitrary ResponseWriter at the type level, -// but it must also implement http.Hijack +// but the writer must also implement http.Hijacker for long-lived connections type mockWsResponseWriter struct { serverRecorder *httptest.ResponseRecorder serverConn net.Conn @@ -196,7 +196,7 @@ func (w mockWsWrite) Write(b []byte) (int, error) { return w(b) } -func TestOneWayWebSocket(t *testing.T) { +func TestWebSocketEventSender(t *testing.T) { t.Parallel() newBaseRequest := func(ctx context.Context) *http.Request { @@ -256,7 +256,7 @@ func TestOneWayWebSocket(t *testing.T) { req.Proto = p.proto writer := newWebsocketWriter() - _, _, err := httpapi.OneWayWebSocket(writer, req) + _, _, err := httpapi.WebSocketEventSender(writer, req) require.ErrorContains(t, err, p.proto) } }) @@ -267,7 +267,7 @@ func TestOneWayWebSocket(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) writer := newWebsocketWriter() - send, _, err := httpapi.OneWayWebSocket(writer, req) + send, _, err := httpapi.WebSocketEventSender(writer, req) require.NoError(t, err) serverPayload := codersdk.ServerSentEvent{ @@ -293,7 +293,7 @@ func TestOneWayWebSocket(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) writer := newWebsocketWriter() - _, done, err := httpapi.OneWayWebSocket(writer, req) + _, done, err := httpapi.WebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -317,7 +317,7 @@ func TestOneWayWebSocket(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) writer := newWebsocketWriter() - _, done, err := httpapi.OneWayWebSocket(writer, req) + _, done, err := httpapi.WebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -347,7 +347,7 @@ func TestOneWayWebSocket(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) writer := newWebsocketWriter() - send, done, err := httpapi.OneWayWebSocket(writer, req) + send, done, err := httpapi.WebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -388,7 +388,7 @@ func TestOneWayWebSocket(t *testing.T) { ctx := testutil.Context(t, timeout) req := newBaseRequest(ctx) writer := newWebsocketWriter() - _, _, err := httpapi.OneWayWebSocket(writer, req) + _, _, err := httpapi.WebSocketEventSender(writer, req) require.NoError(t, err) type Result struct { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 7234fda2cb75a..278615c3a141e 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1109,13 +1109,13 @@ func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.R // @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get] // @x-apidocgen {"skip": true} func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) { - api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocket) + api.watchWorkspaceAgentMetadata(rw, r, httpapi.WebSocketEventSender) } func (api *API) watchWorkspaceAgentMetadata( rw http.ResponseWriter, r *http.Request, - connect httpapi.InitializeConnectionCallback, + connect httpapi.EventSender, ) { // Allow us to interrupt watch via cancel. ctx, cancel := context.WithCancel(r.Context()) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index a4a0027ae7475..c40b3fe34ec5b 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -1732,13 +1732,13 @@ func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) { // @Success 200 {object} codersdk.ServerSentEvent // @Router /workspaces/{workspace}/watch-ws [get] func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) { - api.watchWorkspace(rw, r, httpapi.OneWayWebSocket) + api.watchWorkspace(rw, r, httpapi.WebSocketEventSender) } func (api *API) watchWorkspace( rw http.ResponseWriter, r *http.Request, - connect httpapi.InitializeConnectionCallback, + connect httpapi.EventSender, ) { ctx := r.Context() workspace := httpmw.WorkspaceParam(r) From c7d95d97be2067643ac1f70e94eab60243ce726f Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Wed, 19 Mar 2025 21:06:33 +0000 Subject: [PATCH 4/7] wip: commit progress on tests --- coderd/httpapi/httpapi.go | 21 ++--- coderd/httpapi/httpapi_test.go | 163 +++++++++++++++++++++++---------- coderd/workspaceagents.go | 2 +- coderd/workspaces.go | 2 +- 4 files changed, 126 insertions(+), 62 deletions(-) diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index d43d8a74d131f..c70290ffe56b0 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -326,16 +326,13 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( // Synchronized handling of events (no guarantee of order). go func() { defer close(closed) - - // Send a heartbeat every 15 seconds to avoid the connection being killed. - ticker := time.NewTicker(time.Second * 15) + ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { var event sseEvent - select { - case <-r.Context().Done(): + case <-ctx.Done(): return case event = <-eventC: case <-ticker.C: @@ -357,8 +354,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( sendEvent := func(newEvent codersdk.ServerSentEvent) error { buf := &bytes.Buffer{} - enc := json.NewEncoder(buf) - _, err := buf.WriteString(fmt.Sprintf("event: %s\n", newEvent.Type)) if err != nil { return err @@ -369,6 +364,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( if err != nil { return err } + + enc := json.NewEncoder(buf) err = enc.Encode(newEvent.Data) if err != nil { return err @@ -386,8 +383,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( } select { - case <-r.Context().Done(): - return r.Context().Err() case <-ctx.Done(): return ctx.Err() case <-closed: @@ -397,8 +392,6 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( // for early exit. We don't check closed here because it // can't happen while processing the event. select { - case <-r.Context().Done(): - return r.Context().Err() case <-ctx.Done(): return ctx.Err() case err := <-event.errC: @@ -410,8 +403,8 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( return sendEvent, closed, nil } -// WebSocketEventSender establishes a new WebSocket connection that enforces -// one-way communication from the server to the client. +// OneWayWebSocketEventSender establishes a new WebSocket connection that +// enforces one-way communication from the server to the client. // // The function returned allows you to send a single message to the client, // while the channel lets you listen for when the connection closes. @@ -422,7 +415,7 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) ( // open a workspace in multiple tabs, the entire UI can start to lock up. // WebSockets have no such limitation, no matter what HTTP protocol was used to // establish the connection. -func WebSocketEventSender(rw http.ResponseWriter, r *http.Request) ( +func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) ( func(event codersdk.ServerSentEvent) error, <-chan struct{}, error, diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index 9ab64a4a55f10..aea7fc230be1b 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -162,41 +162,66 @@ func TestWebsocketCloseMsg(t *testing.T) { } // Our WebSocket library accepts any arbitrary ResponseWriter at the type level, -// but the writer must also implement http.Hijacker for long-lived connections -type mockWsResponseWriter struct { +// but the writer must also implement http.Hijacker for long-lived connections. +// The SSE version only requires http.Flusher (no need for the Hijack method). +type mockEventSenderResponseWriter struct { serverRecorder *httptest.ResponseRecorder serverConn net.Conn clientConn net.Conn serverReadWriter *bufio.ReadWriter } -func (m mockWsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return m.serverConn, m.serverReadWriter, nil } -func (m mockWsResponseWriter) Flush() { +func (m mockEventSenderResponseWriter) Flush() { _ = m.serverReadWriter.Flush() } -func (m mockWsResponseWriter) Header() http.Header { +func (m mockEventSenderResponseWriter) Header() http.Header { return m.serverRecorder.Header() } -func (m mockWsResponseWriter) Write(b []byte) (int, error) { +func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) { return m.serverReadWriter.Write(b) } -func (m mockWsResponseWriter) WriteHeader(code int) { +func (m mockEventSenderResponseWriter) WriteHeader(code int) { m.serverRecorder.WriteHeader(code) } -type mockWsWrite func(b []byte) (int, error) +type mockEventSenderWrite func(b []byte) (int, error) -func (w mockWsWrite) Write(b []byte) (int, error) { +func (w mockEventSenderWrite) Write(b []byte) (int, error) { return w(b) } -func TestWebSocketEventSender(t *testing.T) { +func newMockEventSenderWriter() mockEventSenderResponseWriter { + mockServer, mockClient := net.Pipe() + recorder := httptest.NewRecorder() + + var write mockEventSenderWrite = func(b []byte) (int, error) { + serverCount, err := mockServer.Write(b) + if err != nil { + return serverCount, err + } + recorderCount, err := recorder.Write(b) + return min(serverCount, recorderCount), err + } + + return mockEventSenderResponseWriter{ + serverConn: mockServer, + clientConn: mockClient, + serverRecorder: recorder, + serverReadWriter: bufio.NewReadWriter( + bufio.NewReader(mockServer), + bufio.NewWriter(write), + ), + } +} + +func TestOneWayWebSocketEventSender(t *testing.T) { t.Parallel() newBaseRequest := func(ctx context.Context) *http.Request { @@ -213,30 +238,6 @@ func TestWebSocketEventSender(t *testing.T) { return req } - newWebsocketWriter := func() mockWsResponseWriter { - mockServer, mockClient := net.Pipe() - recorder := httptest.NewRecorder() - - var write mockWsWrite = func(b []byte) (int, error) { - serverCount, err := mockServer.Write(b) - if err != nil { - return serverCount, err - } - recorderCount, err := recorder.Write(b) - return min(serverCount, recorderCount), err - } - - return mockWsResponseWriter{ - serverConn: mockServer, - clientConn: mockClient, - serverRecorder: recorder, - serverReadWriter: bufio.NewReadWriter( - bufio.NewReader(mockServer), - bufio.NewWriter(write), - ), - } - } - t.Run("Produces error if the socket connection could not be established", func(t *testing.T) { t.Parallel() @@ -255,8 +256,8 @@ func TestWebSocketEventSender(t *testing.T) { req.ProtoMinor = p.minor req.Proto = p.proto - writer := newWebsocketWriter() - _, _, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.ErrorContains(t, err, p.proto) } }) @@ -266,8 +267,8 @@ func TestWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newWebsocketWriter() - send, _, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + send, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) serverPayload := codersdk.ServerSentEvent{ @@ -292,8 +293,8 @@ func TestWebSocketEventSender(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) - writer := newWebsocketWriter() - _, done, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -316,8 +317,8 @@ func TestWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newWebsocketWriter() - _, done, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -346,8 +347,8 @@ func TestWebSocketEventSender(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) - writer := newWebsocketWriter() - send, done, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + send, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) successC := make(chan bool) @@ -387,8 +388,8 @@ func TestWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, timeout) req := newBaseRequest(ctx) - writer := newWebsocketWriter() - _, _, err := httpapi.WebSocketEventSender(writer, req) + writer := newMockEventSenderWriter() + _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) type Result struct { @@ -420,3 +421,73 @@ func TestWebSocketEventSender(t *testing.T) { require.True(t, result.Success) }) } + +func TestServerSentEventSender(t *testing.T) { + t.Parallel() + + newBaseRequest := func(ctx context.Context) *http.Request { + url := "ws://www.fake-website.com/logs" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + return req + } + + t.Run("Mutates response headers to support SSE connections", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newMockEventSenderWriter() + _, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + h := writer.Header() + require.Equal(t, h.Get("Content-Type"), "text/event-stream") + require.Equal(t, h.Get("Cache-Control"), "no-cache") + require.Equal(t, h.Get("Connection"), "keep-alive") + require.Equal(t, h.Get("X-Accel-Buffering"), "no") + }) + + t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + req := newBaseRequest(ctx) + writer := newMockEventSenderWriter() + send, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + serverPayload := codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeData, + Data: "Blah", + } + err = send(serverPayload) + require.NoError(t, err) + + // The client connection will receive a little bit of additional data on + // top of the main payload. Have to make sure check has tolerance for + // extra data being present + serverBytes, err := json.Marshal(serverPayload) + require.NoError(t, err) + + // This is the part that's breaking + clientBytes, err := io.ReadAll(writer.clientConn) + require.NoError(t, err) + require.True(t, bytes.Contains(clientBytes, serverBytes)) + }) + + t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) { + t.Parallel() + t.FailNow() + }) + + t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) { + t.Parallel() + t.FailNow() + }) + + t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) { + t.Parallel() + t.FailNow() + }) +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 278615c3a141e..9810abbc6a1da 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1109,7 +1109,7 @@ func (api *API) watchWorkspaceAgentMetadataSSE(rw http.ResponseWriter, r *http.R // @Router /workspaceagents/{workspaceagent}/watch-metadata-ws [get] // @x-apidocgen {"skip": true} func (api *API) watchWorkspaceAgentMetadataWS(rw http.ResponseWriter, r *http.Request) { - api.watchWorkspaceAgentMetadata(rw, r, httpapi.WebSocketEventSender) + api.watchWorkspaceAgentMetadata(rw, r, httpapi.OneWayWebSocketEventSender) } func (api *API) watchWorkspaceAgentMetadata( diff --git a/coderd/workspaces.go b/coderd/workspaces.go index c40b3fe34ec5b..1091eaf3253ed 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -1732,7 +1732,7 @@ func (api *API) watchWorkspaceSSE(rw http.ResponseWriter, r *http.Request) { // @Success 200 {object} codersdk.ServerSentEvent // @Router /workspaces/{workspace}/watch-ws [get] func (api *API) watchWorkspaceWS(rw http.ResponseWriter, r *http.Request) { - api.watchWorkspace(rw, r, httpapi.WebSocketEventSender) + api.watchWorkspace(rw, r, httpapi.OneWayWebSocketEventSender) } func (api *API) watchWorkspace( From 792aa2dc6e3799e1a0e703103fd74eba9448fbde Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Wed, 19 Mar 2025 23:08:16 +0000 Subject: [PATCH 5/7] wip: commit more test progress --- coderd/httpapi/httpapi_test.go | 170 +++++++++++++++++++++++---------- 1 file changed, 119 insertions(+), 51 deletions(-) diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index aea7fc230be1b..0409d1469551d 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -163,31 +163,32 @@ func TestWebsocketCloseMsg(t *testing.T) { // Our WebSocket library accepts any arbitrary ResponseWriter at the type level, // but the writer must also implement http.Hijacker for long-lived connections. -// The SSE version only requires http.Flusher (no need for the Hijack method). -type mockEventSenderResponseWriter struct { +type mockOneWaySocketWriter struct { serverRecorder *httptest.ResponseRecorder serverConn net.Conn clientConn net.Conn serverReadWriter *bufio.ReadWriter + testContext *testing.T } -func (m mockEventSenderResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return m.serverConn, m.serverReadWriter, nil } -func (m mockEventSenderResponseWriter) Flush() { - _ = m.serverReadWriter.Flush() +func (m mockOneWaySocketWriter) Flush() { + err := m.serverReadWriter.Flush() + require.NoError(m.testContext, err) } -func (m mockEventSenderResponseWriter) Header() http.Header { +func (m mockOneWaySocketWriter) Header() http.Header { return m.serverRecorder.Header() } -func (m mockEventSenderResponseWriter) Write(b []byte) (int, error) { +func (m mockOneWaySocketWriter) Write(b []byte) (int, error) { return m.serverReadWriter.Write(b) } -func (m mockEventSenderResponseWriter) WriteHeader(code int) { +func (m mockOneWaySocketWriter) WriteHeader(code int) { m.serverRecorder.WriteHeader(code) } @@ -197,30 +198,6 @@ func (w mockEventSenderWrite) Write(b []byte) (int, error) { return w(b) } -func newMockEventSenderWriter() mockEventSenderResponseWriter { - mockServer, mockClient := net.Pipe() - recorder := httptest.NewRecorder() - - var write mockEventSenderWrite = func(b []byte) (int, error) { - serverCount, err := mockServer.Write(b) - if err != nil { - return serverCount, err - } - recorderCount, err := recorder.Write(b) - return min(serverCount, recorderCount), err - } - - return mockEventSenderResponseWriter{ - serverConn: mockServer, - clientConn: mockClient, - serverRecorder: recorder, - serverReadWriter: bufio.NewReadWriter( - bufio.NewReader(mockServer), - bufio.NewWriter(write), - ), - } -} - func TestOneWayWebSocketEventSender(t *testing.T) { t.Parallel() @@ -238,6 +215,34 @@ func TestOneWayWebSocketEventSender(t *testing.T) { return req } + newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter { + mockServer, mockClient := net.Pipe() + recorder := httptest.NewRecorder() + + var write mockEventSenderWrite = func(b []byte) (int, error) { + serverCount, err := mockServer.Write(b) + if err != nil { + return 0, err + } + recorderCount, err := recorder.Write(b) + if err != nil { + return 0, err + } + return min(serverCount, recorderCount), nil + } + + return mockOneWaySocketWriter{ + testContext: t, + serverConn: mockServer, + clientConn: mockClient, + serverRecorder: recorder, + serverReadWriter: bufio.NewReadWriter( + bufio.NewReader(mockServer), + bufio.NewWriter(write), + ), + } + } + t.Run("Produces error if the socket connection could not be established", func(t *testing.T) { t.Parallel() @@ -256,7 +261,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { req.ProtoMinor = p.minor req.Proto = p.proto - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.ErrorContains(t, err, p.proto) } @@ -267,7 +272,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) send, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) @@ -293,7 +298,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) @@ -317,7 +322,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) _, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) @@ -347,7 +352,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) send, done, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) @@ -388,7 +393,7 @@ func TestOneWayWebSocketEventSender(t *testing.T) { ctx := testutil.Context(t, timeout) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newOneWayWriter(t) _, _, err := httpapi.OneWayWebSocketEventSender(writer, req) require.NoError(t, err) @@ -422,6 +427,42 @@ func TestOneWayWebSocketEventSender(t *testing.T) { }) } +// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level, +// but the writer must also implement http.Flusher for long-lived connections +type mockServerSentWriter struct { + serverRecorder *httptest.ResponseRecorder + serverConn net.Conn + clientConn net.Conn + buffer *bytes.Buffer + testContext *testing.T +} + +func (m mockServerSentWriter) Flush() { + b := m.buffer.Bytes() + _, err := m.serverConn.Write(b) + require.NoError(m.testContext, err) + m.buffer.Reset() + + // Must close server connection to indicate EOF for any reads from the + // client connection; otherwise reads block forever. This is a testing + // limitation compared to the one-way websockets, since we have no way to + // frame the data and auto-indicate EOF for each message + err = m.serverConn.Close() + require.NoError(m.testContext, err) +} + +func (m mockServerSentWriter) Header() http.Header { + return m.serverRecorder.Header() +} + +func (m mockServerSentWriter) Write(b []byte) (int, error) { + return m.buffer.Write(b) +} + +func (m mockServerSentWriter) WriteHeader(code int) { + m.serverRecorder.WriteHeader(code) +} + func TestServerSentEventSender(t *testing.T) { t.Parallel() @@ -432,12 +473,23 @@ func TestServerSentEventSender(t *testing.T) { return req } + newServerSentWriter := func(t *testing.T) mockServerSentWriter { + mockServer, mockClient := net.Pipe() + return mockServerSentWriter{ + testContext: t, + serverRecorder: httptest.NewRecorder(), + clientConn: mockClient, + serverConn: mockServer, + buffer: &bytes.Buffer{}, + } + } + t.Run("Mutates response headers to support SSE connections", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newServerSentWriter(t) _, _, err := httpapi.ServerSentEventSender(writer, req) require.NoError(t, err) @@ -453,7 +505,7 @@ func TestServerSentEventSender(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) req := newBaseRequest(ctx) - writer := newMockEventSenderWriter() + writer := newServerSentWriter(t) send, _, err := httpapi.ServerSentEventSender(writer, req) require.NoError(t, err) @@ -464,30 +516,46 @@ func TestServerSentEventSender(t *testing.T) { err = send(serverPayload) require.NoError(t, err) - // The client connection will receive a little bit of additional data on - // top of the main payload. Have to make sure check has tolerance for - // extra data being present - serverBytes, err := json.Marshal(serverPayload) - require.NoError(t, err) - - // This is the part that's breaking clientBytes, err := io.ReadAll(writer.clientConn) require.NoError(t, err) - require.True(t, bytes.Contains(clientBytes, serverBytes)) + require.Equal( + t, + string(clientBytes), + "event: data\ndata: \"Blah\"\n\n", + ) }) t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) { t.Parallel() - t.FailNow() + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + _, done, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + successC := make(chan bool) + ticker := time.NewTicker(testutil.WaitShort) + go func() { + select { + case <-done: + successC <- true + case <-ticker.C: + successC <- false + } + }() + + cancel() + require.True(t, <-successC) }) t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) { - t.Parallel() t.FailNow() + t.Parallel() }) t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) { - t.Parallel() t.FailNow() + t.Parallel() }) } From bcd14296670abd68bc60d2c9dd247dd8a4075ae0 Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Wed, 19 Mar 2025 23:51:15 +0000 Subject: [PATCH 6/7] chore: get all tests passing --- coderd/httpapi/httpapi_test.go | 46 +++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index 0409d1469551d..44675e78a255d 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -549,13 +549,47 @@ func TestServerSentEventSender(t *testing.T) { require.True(t, <-successC) }) - t.Run("Cancels the entire connection if the request context cancels", func(t *testing.T) { - t.FailNow() - t.Parallel() - }) - t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) { - t.FailNow() t.Parallel() + + // Need add at least three heartbeats for something to be reliably + // counted as an interval, but also need some wiggle room + heartbeatCount := 3 + hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval + timeout := hbDuration + (5 * time.Second) + + ctx := testutil.Context(t, timeout) + req := newBaseRequest(ctx) + writer := newServerSentWriter(t) + _, _, err := httpapi.ServerSentEventSender(writer, req) + require.NoError(t, err) + + type Result struct { + Err error + Success bool + } + resultC := make(chan Result) + go func() { + err := writer. + clientConn. + SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + resultC <- Result{err, false} + return + } + for range heartbeatCount { + pingBuffer := make([]byte, 1) + pingSize, err := writer.clientConn.Read(pingBuffer) + if err != nil || pingSize != 1 { + resultC <- Result{err, false} + return + } + } + resultC <- Result{nil, true} + }() + + result := <-resultC + require.NoError(t, result.Err) + require.True(t, result.Success) }) } From 8f1eca0c91f38987ed0d18c24291c3d493d1ad0e Mon Sep 17 00:00:00 2001 From: Michael Smith Date: Fri, 28 Mar 2025 16:17:26 -0400 Subject: [PATCH 7/7] chore: add support for one-way WebSockets to UI (#16855) Closes https://github.com/coder/coder/issues/16777 ## Changes made - Added `OneWayWebSocket` utility class to help enforce one-way communication from the server to the client - Updated all client client code to use the new WebSocket-based endpoints made to replace the current SSE-based endpoints - Updated WebSocket event handlers to be aware of new protocols - Refactored existing `useEffect` calls and removed some synchronization bugs - Removed dependencies and types for dealing with SSEs - Addressed some minor Biome warnings --- site/package.json | 1 - site/pnpm-lock.yaml | 8 - site/src/@types/eventsourcemock.d.ts | 1 - site/src/api/api.ts | 76 +-- .../NotificationsInbox/NotificationsInbox.tsx | 36 +- site/src/modules/resources/AgentMetadata.tsx | 93 ++-- .../modules/templates/useWatchVersionLogs.ts | 42 +- .../WorkspacePage/WorkspacePage.test.tsx | 20 +- .../src/pages/WorkspacePage/WorkspacePage.tsx | 27 +- site/src/utils/OneWayWebSocket.test.ts | 492 ++++++++++++++++++ site/src/utils/OneWayWebSocket.ts | 198 +++++++ 11 files changed, 843 insertions(+), 151 deletions(-) delete mode 100644 site/src/@types/eventsourcemock.d.ts create mode 100644 site/src/utils/OneWayWebSocket.test.ts create mode 100644 site/src/utils/OneWayWebSocket.ts diff --git a/site/package.json b/site/package.json index 7f45637237cf7..51ec024ae2fa1 100644 --- a/site/package.json +++ b/site/package.json @@ -166,7 +166,6 @@ "@vitejs/plugin-react": "4.3.4", "autoprefixer": "10.4.20", "chromatic": "11.25.2", - "eventsourcemock": "2.0.0", "express": "4.21.2", "jest": "29.7.0", "jest-canvas-mock": "2.5.2", diff --git a/site/pnpm-lock.yaml b/site/pnpm-lock.yaml index d08ab3c523083..fc5dbb43876f6 100644 --- a/site/pnpm-lock.yaml +++ b/site/pnpm-lock.yaml @@ -403,9 +403,6 @@ importers: chromatic: specifier: 11.25.2 version: 11.25.2 - eventsourcemock: - specifier: 2.0.0 - version: 2.0.0 express: specifier: 4.21.2 version: 4.21.2 @@ -3796,9 +3793,6 @@ packages: eventemitter3@4.0.7: resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==, tarball: https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz} - eventsourcemock@2.0.0: - resolution: {integrity: sha512-tSmJnuE+h6A8/hLRg0usf1yL+Q8w01RQtmg0Uzgoxk/HIPZrIUeAr/A4es/8h1wNsoG8RdiESNQLTKiNwbSC3Q==, tarball: https://registry.npmjs.org/eventsourcemock/-/eventsourcemock-2.0.0.tgz} - execa@5.1.1: resolution: {integrity: sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==, tarball: https://registry.npmjs.org/execa/-/execa-5.1.1.tgz} engines: {node: '>=10'} @@ -10017,8 +10011,6 @@ snapshots: eventemitter3@4.0.7: {} - eventsourcemock@2.0.0: {} - execa@5.1.1: dependencies: cross-spawn: 7.0.6 diff --git a/site/src/@types/eventsourcemock.d.ts b/site/src/@types/eventsourcemock.d.ts deleted file mode 100644 index 296c4f19c33ce..0000000000000 --- a/site/src/@types/eventsourcemock.d.ts +++ /dev/null @@ -1 +0,0 @@ -declare module "eventsourcemock"; diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 85953bbce736f..3a43772a02657 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -22,9 +22,10 @@ import globalAxios, { type AxiosInstance, isAxiosError } from "axios"; import type dayjs from "dayjs"; import userAgentParser from "ua-parser-js"; +import { OneWayWebSocket } from "utils/OneWayWebSocket"; import { delay } from "../utils/delay"; -import * as TypesGen from "./typesGenerated"; import type { PostWorkspaceUsageRequest } from "./typesGenerated"; +import * as TypesGen from "./typesGenerated"; const getMissingParameters = ( oldBuildParameters: TypesGen.WorkspaceBuildParameter[], @@ -101,61 +102,40 @@ const getMissingParameters = ( }; /** - * * @param agentId - * @returns An EventSource that emits agent metadata event objects - * (ServerSentEvent) + * @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events. */ -export const watchAgentMetadata = (agentId: string): EventSource => { - return new EventSource( - `${location.protocol}//${location.host}/api/v2/workspaceagents/${agentId}/watch-metadata`, - { withCredentials: true }, - ); +export const watchAgentMetadata = ( + agentId: string, +): OneWayWebSocket => { + return new OneWayWebSocket({ + apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, + }); }; /** - * @returns {EventSource} An EventSource that emits workspace event objects - * (ServerSentEvent) + * @returns {OneWayWebSocket} A OneWayWebSocket that emits Server-Sent Events. */ -export const watchWorkspace = (workspaceId: string): EventSource => { - return new EventSource( - `${location.protocol}//${location.host}/api/v2/workspaces/${workspaceId}/watch`, - { withCredentials: true }, - ); +export const watchWorkspace = ( + workspaceId: string, +): OneWayWebSocket => { + return new OneWayWebSocket({ + apiRoute: `/api/v2/workspaces/${workspaceId}/watch-ws`, + }); }; -type WatchInboxNotificationsParams = { +type WatchInboxNotificationsParams = Readonly<{ read_status?: "read" | "unread" | "all"; -}; +}>; -export const watchInboxNotifications = ( - onNewNotification: (res: TypesGen.GetInboxNotificationResponse) => void, +export function watchInboxNotifications( params?: WatchInboxNotificationsParams, -) => { - const searchParams = new URLSearchParams(params); - const socket = createWebSocket( - "/api/v2/notifications/inbox/watch", - searchParams, - ); - - socket.addEventListener("message", (event) => { - try { - const res = JSON.parse( - event.data, - ) as TypesGen.GetInboxNotificationResponse; - onNewNotification(res); - } catch (error) { - console.warn("Error parsing inbox notification: ", error); - } - }); - - socket.addEventListener("error", (event) => { - console.warn("Watch inbox notifications error: ", event); - socket.close(); +): OneWayWebSocket { + return new OneWayWebSocket({ + apiRoute: "/api/v2/notifications/inbox/watch", + searchParams: params, }); - - return socket; -}; +} export const getURLWithSearchParams = ( basePath: string, @@ -1125,7 +1105,7 @@ class ApiMethods { }; getWorkspaceByOwnerAndName = async ( - username = "me", + username: string, workspaceName: string, params?: TypesGen.WorkspaceOptions, ): Promise => { @@ -1138,7 +1118,7 @@ class ApiMethods { }; getWorkspaceBuildByNumber = async ( - username = "me", + username: string, workspaceName: string, buildNumber: number, ): Promise => { @@ -1324,7 +1304,7 @@ class ApiMethods { }; createWorkspace = async ( - userId = "me", + userId: string, workspace: TypesGen.CreateWorkspaceRequest, ): Promise => { const response = await this.axios.post( @@ -2542,7 +2522,7 @@ function createWebSocket( ) { const protocol = location.protocol === "https:" ? "wss:" : "ws:"; const socket = new WebSocket( - `${protocol}//${location.host}${path}?${params.toString()}`, + `${protocol}//${location.host}${path}?${params}`, ); socket.binaryType = "blob"; return socket; diff --git a/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx b/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx index 656d87fbe31d3..cdbf0941b7fdb 100644 --- a/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx +++ b/site/src/modules/notifications/NotificationsInbox/NotificationsInbox.tsx @@ -61,21 +61,31 @@ export const NotificationsInbox: FC = ({ ); useEffect(() => { - const socket = watchInboxNotifications( - (res) => { - updateNotificationsCache((prev) => { - return { - unread_count: res.unread_count, - notifications: [res.notification, ...prev.notifications], - }; - }); - }, - { read_status: "unread" }, - ); + const socket = watchInboxNotifications({ read_status: "unread" }); - return () => { + socket.addEventListener("message", (e) => { + if (e.parseError) { + console.warn("Error parsing inbox notification: ", e.parseError); + return; + } + + const msg = e.parsedMessage; + updateNotificationsCache((current) => { + return { + unread_count: msg.unread_count, + notifications: [msg.notification, ...current.notifications], + }; + }); + }); + + socket.addEventListener("error", () => { + displayError( + "Unable to retrieve latest inbox notifications. Please try refreshing the browser.", + ); socket.close(); - }; + }); + + return () => socket.close(); }, [updateNotificationsCache]); const { diff --git a/site/src/modules/resources/AgentMetadata.tsx b/site/src/modules/resources/AgentMetadata.tsx index 81b5a14994e81..5e5501809ee49 100644 --- a/site/src/modules/resources/AgentMetadata.tsx +++ b/site/src/modules/resources/AgentMetadata.tsx @@ -3,9 +3,11 @@ import Skeleton from "@mui/material/Skeleton"; import Tooltip from "@mui/material/Tooltip"; import { watchAgentMetadata } from "api/api"; import type { + ServerSentEvent, WorkspaceAgent, WorkspaceAgentMetadata, } from "api/typesGenerated"; +import { displayError } from "components/GlobalSnackbar/utils"; import { Stack } from "components/Stack/Stack"; import dayjs from "dayjs"; import { @@ -17,6 +19,7 @@ import { useState, } from "react"; import { MONOSPACE_FONT_FAMILY } from "theme/constants"; +import type { OneWayWebSocket } from "utils/OneWayWebSocket"; type ItemStatus = "stale" | "valid" | "loading"; @@ -42,50 +45,82 @@ interface AgentMetadataProps { storybookMetadata?: WorkspaceAgentMetadata[]; } +const maxSocketErrorRetryCount = 3; + export const AgentMetadata: FC = ({ agent, storybookMetadata, }) => { - const [metadata, setMetadata] = useState< - WorkspaceAgentMetadata[] | undefined - >(undefined); - + const [activeMetadata, setActiveMetadata] = useState(storybookMetadata); useEffect(() => { + // This is an unfortunate pitfall with this component's testing setup, + // but even though we use the value of storybookMetadata as the initial + // value of the activeMetadata, we cannot put activeMetadata itself into + // the dependency array. If we did, we would destroy and rebuild each + // connection every single time a new message comes in from the socket, + // because the socket has to be wired up to the state setter if (storybookMetadata !== undefined) { - setMetadata(storybookMetadata); return; } - let timeout: ReturnType | undefined = undefined; - - const connect = (): (() => void) => { - const source = watchAgentMetadata(agent.id); + let timeoutId: number | undefined = undefined; + let activeSocket: OneWayWebSocket | null = null; + let retries = 0; + + const createNewConnection = () => { + const socket = watchAgentMetadata(agent.id); + activeSocket = socket; + + socket.addEventListener("error", () => { + setActiveMetadata(undefined); + window.clearTimeout(timeoutId); + + // The error event is supposed to fire when an error happens + // with the connection itself, which implies that the connection + // would auto-close. Couldn't find a definitive answer on MDN, + // though, so closing it manually just to be safe + socket.close(); + activeSocket = null; + + retries++; + if (retries >= maxSocketErrorRetryCount) { + displayError( + "Unexpected disconnect while watching Metadata changes. Please try refreshing the page.", + ); + return; + } - source.onerror = (e) => { - console.error("received error in watch stream", e); - setMetadata(undefined); - source.close(); + displayError( + "Unexpected disconnect while watching Metadata changes. Creating new connection...", + ); + timeoutId = window.setTimeout(() => { + createNewConnection(); + }, 3_000); + }); - timeout = setTimeout(() => { - connect(); - }, 3000); - }; + socket.addEventListener("message", (e) => { + if (e.parseError) { + displayError( + "Unable to process newest response from server. Please try refreshing the page.", + ); + return; + } - source.addEventListener("data", (e) => { - const data = JSON.parse(e.data); - setMetadata(data); - }); - return () => { - if (timeout !== undefined) { - clearTimeout(timeout); + const msg = e.parsedMessage; + if (msg.type === "data") { + setActiveMetadata(msg.data as WorkspaceAgentMetadata[]); } - source.close(); - }; + }); + }; + + createNewConnection(); + return () => { + window.clearTimeout(timeoutId); + activeSocket?.close(); }; - return connect(); }, [agent.id, storybookMetadata]); - if (metadata === undefined) { + if (activeMetadata === undefined) { return (
@@ -93,7 +128,7 @@ export const AgentMetadata: FC = ({ ); } - return ; + return ; }; export const AgentMetadataSkeleton: FC = () => { diff --git a/site/src/modules/templates/useWatchVersionLogs.ts b/site/src/modules/templates/useWatchVersionLogs.ts index 5574e083a9849..1e77b0eb1b073 100644 --- a/site/src/modules/templates/useWatchVersionLogs.ts +++ b/site/src/modules/templates/useWatchVersionLogs.ts @@ -1,46 +1,38 @@ import { watchBuildLogsByTemplateVersionId } from "api/api"; import type { ProvisionerJobLog, TemplateVersion } from "api/typesGenerated"; +import { useEffectEvent } from "hooks/hookPolyfills"; import { useEffect, useState } from "react"; export const useWatchVersionLogs = ( templateVersion: TemplateVersion | undefined, options?: { onDone: () => Promise }, ) => { - const [logs, setLogs] = useState(); + const [logs, setLogs] = useState(); const templateVersionId = templateVersion?.id; - const templateVersionStatus = templateVersion?.job.status; + const [cachedVersionId, setCachedVersionId] = useState(templateVersionId); + if (cachedVersionId !== templateVersionId) { + setCachedVersionId(templateVersionId); + setLogs([]); + } - // biome-ignore lint/correctness/useExhaustiveDependencies: consider refactoring + const stableOnDone = useEffectEvent(() => options?.onDone()); + const status = templateVersion?.job.status; + const canWatch = status === "running" || status === "pending"; useEffect(() => { - setLogs(undefined); - }, [templateVersionId]); - - useEffect(() => { - if (!templateVersionId || !templateVersionStatus) { - return; - } - - if ( - templateVersionStatus !== "running" && - templateVersionStatus !== "pending" - ) { + if (!templateVersionId || !canWatch) { return; } const socket = watchBuildLogsByTemplateVersionId(templateVersionId, { - onMessage: (log) => { - setLogs((logs) => (logs ? [...logs, log] : [log])); - }, - onDone: options?.onDone, - onError: (error) => { - console.error(error); + onError: (error) => console.error(error), + onDone: stableOnDone, + onMessage: (newLog) => { + setLogs((current) => [...(current ?? []), newLog]); }, }); - return () => { - socket.close(); - }; - }, [options?.onDone, templateVersionId, templateVersionStatus]); + return () => socket.close(); + }, [stableOnDone, canWatch, templateVersionId]); return logs; }; diff --git a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx index 50f47a4721320..d120ad5546c17 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx @@ -2,7 +2,7 @@ import { screen, waitFor, within } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import * as apiModule from "api/api"; import type { TemplateVersionParameter, Workspace } from "api/typesGenerated"; -import EventSourceMock from "eventsourcemock"; +import MockServerSocket from "jest-websocket-mock"; import { DashboardContext, type DashboardProvider, @@ -84,23 +84,11 @@ const testButton = async ( const user = userEvent.setup(); await user.click(button); - expect(actionMock).toBeCalled(); + expect(actionMock).toHaveBeenCalled(); }; -let originalEventSource: typeof window.EventSource; - -beforeAll(() => { - originalEventSource = window.EventSource; - // mocking out EventSource for SSE - window.EventSource = EventSourceMock; -}); - -beforeEach(() => { - jest.resetAllMocks(); -}); - -afterAll(() => { - window.EventSource = originalEventSource; +afterEach(() => { + MockServerSocket.clean(); }); describe("WorkspacePage", () => { diff --git a/site/src/pages/WorkspacePage/WorkspacePage.tsx b/site/src/pages/WorkspacePage/WorkspacePage.tsx index cd2b5f48cb6d3..a55971abfb576 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.tsx @@ -5,6 +5,7 @@ import { workspaceBuildsKey } from "api/queries/workspaceBuilds"; import { workspaceByOwnerAndName } from "api/queries/workspaces"; import type { Workspace } from "api/typesGenerated"; import { ErrorAlert } from "components/Alert/ErrorAlert"; +import { displayError } from "components/GlobalSnackbar/utils"; import { Loader } from "components/Loader/Loader"; import { Margins } from "components/Margins/Margins"; import { useEffectEvent } from "hooks/hookPolyfills"; @@ -82,20 +83,26 @@ export const WorkspacePage: FC = () => { return; } - const eventSource = watchWorkspace(workspaceId); + const socket = watchWorkspace(workspaceId); + socket.addEventListener("message", (event) => { + if (event.parseError) { + displayError( + "Unable to process latest data from the server. Please try refreshing the page.", + ); + return; + } - eventSource.addEventListener("data", async (event) => { - const newWorkspaceData = JSON.parse(event.data) as Workspace; - await updateWorkspaceData(newWorkspaceData); + if (event.parsedMessage.type === "data") { + updateWorkspaceData(event.parsedMessage.data as Workspace); + } }); - - eventSource.addEventListener("error", (event) => { - console.error("Error on getting workspace changes.", event); + socket.addEventListener("error", () => { + displayError( + "Unable to get workspace changes. Connection has been closed.", + ); }); - return () => { - eventSource.close(); - }; + return () => socket.close(); }, [updateWorkspaceData, workspaceId]); // Page statuses diff --git a/site/src/utils/OneWayWebSocket.test.ts b/site/src/utils/OneWayWebSocket.test.ts new file mode 100644 index 0000000000000..c6b00b593111f --- /dev/null +++ b/site/src/utils/OneWayWebSocket.test.ts @@ -0,0 +1,492 @@ +/** + * @file Sets up unit tests for OneWayWebSocket. + * + * 2025-03-18 - Really wanted to define these as integration tests with MSW, but + * getting it set up correctly for Jest and JSDOM got a little screwy. That can + * be revisited in the future, but in the meantime, we're assuming that the base + * WebSocket class doesn't have any bugs, and can safely be mocked out. + */ + +import { + type OneWayMessageEvent, + OneWayWebSocket, + type WebSocketEventType, +} from "./OneWayWebSocket"; + +type MockPublisher = Readonly<{ + publishMessage: (event: MessageEvent) => void; + publishError: (event: ErrorEvent) => void; + publishClose: (event: CloseEvent) => void; + publishOpen: (event: Event) => void; +}>; + +function createMockWebSocket( + url: string, + protocols?: string | string[], +): readonly [WebSocket, MockPublisher] { + type EventMap = { + message: MessageEvent; + error: ErrorEvent; + close: CloseEvent; + open: Event; + }; + type CallbackStore = { + [K in keyof EventMap]: ((event: EventMap[K]) => void)[]; + }; + + let activeProtocol: string; + if (Array.isArray(protocols)) { + activeProtocol = protocols[0] ?? ""; + } else if (typeof protocols === "string") { + activeProtocol = protocols; + } else { + activeProtocol = ""; + } + + let closed = false; + const store: CallbackStore = { + message: [], + error: [], + close: [], + open: [], + }; + + const mockSocket: WebSocket = { + CONNECTING: 0, + OPEN: 1, + CLOSING: 2, + CLOSED: 3, + + url, + protocol: activeProtocol, + readyState: 1, + binaryType: "blob", + bufferedAmount: 0, + extensions: "", + onclose: null, + onerror: null, + onmessage: null, + onopen: null, + send: jest.fn(), + dispatchEvent: jest.fn(), + + addEventListener: ( + eventType: E, + callback: WebSocketEventMap[E], + ) => { + if (closed) { + return; + } + + const subscribers = store[eventType]; + const cb = callback as unknown as CallbackStore[E][0]; + if (!subscribers.includes(cb)) { + subscribers.push(cb); + } + }, + + removeEventListener: ( + eventType: E, + callback: WebSocketEventMap[E], + ) => { + if (closed) { + return; + } + + const subscribers = store[eventType]; + const cb = callback as unknown as CallbackStore[E][0]; + if (subscribers.includes(cb)) { + const updated = store[eventType].filter((c) => c !== cb); + store[eventType] = updated as unknown as CallbackStore[E]; + } + }, + + close: () => { + closed = true; + }, + }; + + const publisher: MockPublisher = { + publishOpen: (event) => { + if (closed) { + return; + } + for (const sub of store.open) { + sub(event); + } + }, + + publishError: (event) => { + if (closed) { + return; + } + for (const sub of store.error) { + sub(event); + } + }, + + publishMessage: (event) => { + if (closed) { + return; + } + for (const sub of store.message) { + sub(event); + } + }, + + publishClose: (event) => { + if (closed) { + return; + } + for (const sub of store.close) { + sub(event); + } + }, + }; + + return [mockSocket, publisher] as const; +} + +describe(OneWayWebSocket.name, () => { + const dummyRoute = "/api/v2/blah"; + + it("Errors out if API route does not start with '/api/v2/'", () => { + const testRoutes: string[] = ["blah", "", "/", "/api", "/api/v225"]; + + for (const r of testRoutes) { + expect(() => { + new OneWayWebSocket({ + apiRoute: r, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + }); + }).toThrow(Error); + } + }); + + it("Lets a consumer add an event listener of each type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(1); + expect(onClose).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledTimes(1); + }); + + it("Lets a consumer remove an event listener of each type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + + oneWay.removeEventListener("open", onOpen); + oneWay.removeEventListener("close", onClose); + oneWay.removeEventListener("error", onError); + oneWay.removeEventListener("message", onMessage); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(0); + expect(onClose).toHaveBeenCalledTimes(0); + expect(onError).toHaveBeenCalledTimes(0); + expect(onMessage).toHaveBeenCalledTimes(0); + }); + + it("Only calls each callback once if callback is added multiple times", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen = jest.fn(); + const onClose = jest.fn(); + const onError = jest.fn(); + const onMessage = jest.fn(); + + for (let i = 0; i < 10; i++) { + oneWay.addEventListener("open", onOpen); + oneWay.addEventListener("close", onClose); + oneWay.addEventListener("error", onError); + oneWay.addEventListener("message", onMessage); + } + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen).toHaveBeenCalledTimes(1); + expect(onClose).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledTimes(1); + }); + + it("Lets consumers register multiple callbacks for each event type", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onOpen1 = jest.fn(); + const onClose1 = jest.fn(); + const onError1 = jest.fn(); + const onMessage1 = jest.fn(); + oneWay.addEventListener("open", onOpen1); + oneWay.addEventListener("close", onClose1); + oneWay.addEventListener("error", onError1); + oneWay.addEventListener("message", onMessage1); + + const onOpen2 = jest.fn(); + const onClose2 = jest.fn(); + const onError2 = jest.fn(); + const onMessage2 = jest.fn(); + oneWay.addEventListener("open", onOpen2); + oneWay.addEventListener("close", onClose2); + oneWay.addEventListener("error", onError2); + oneWay.addEventListener("message", onMessage2); + + publisher.publishOpen(new Event("open")); + publisher.publishClose(new CloseEvent("close")); + publisher.publishError( + new ErrorEvent("error", { + error: new Error("Whoops - connection broke"), + }), + ); + publisher.publishMessage( + new MessageEvent("message", { + data: "null", + }), + ); + + expect(onOpen1).toHaveBeenCalledTimes(1); + expect(onClose1).toHaveBeenCalledTimes(1); + expect(onError1).toHaveBeenCalledTimes(1); + expect(onMessage1).toHaveBeenCalledTimes(1); + + expect(onOpen2).toHaveBeenCalledTimes(1); + expect(onClose2).toHaveBeenCalledTimes(1); + expect(onError2).toHaveBeenCalledTimes(1); + expect(onMessage2).toHaveBeenCalledTimes(1); + }); + + it("Computes the socket protocol based on the browser location protocol", () => { + const oneWay1 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + location: { + protocol: "https:", + host: "www.cool.com", + }, + }); + const oneWay2 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + location: { + protocol: "http:", + host: "www.cool.com", + }, + }); + + expect(oneWay1.url).toMatch(/^wss:\/\//); + expect(oneWay2.url).toMatch(/^ws:\/\//); + }); + + it("Gives consumers pre-parsed versions of message events", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onMessage = jest.fn(); + oneWay.addEventListener("message", onMessage); + + const payload = { + value: 5, + cool: "yes", + }; + const event = new MessageEvent("message", { + data: JSON.stringify(payload), + }); + + publisher.publishMessage(event); + expect(onMessage).toHaveBeenCalledWith({ + sourceEvent: event, + parsedMessage: payload, + parseError: undefined, + }); + }); + + it("Exposes parsing error if message payload could not be parsed as JSON", () => { + let publisher!: MockPublisher; + const oneWay = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket, pub] = createMockWebSocket(url, protocols); + publisher = pub; + return socket; + }, + }); + + const onMessage = jest.fn(); + oneWay.addEventListener("message", onMessage); + + const payload = "definitely not valid JSON"; + const event = new MessageEvent("message", { + data: payload, + }); + publisher.publishMessage(event); + + const arg: OneWayMessageEvent = onMessage.mock.lastCall[0]; + expect(arg.sourceEvent).toEqual(event); + expect(arg.parsedMessage).toEqual(undefined); + expect(arg.parseError).toBeInstanceOf(Error); + }); + + it("Passes all search param values through Websocket URL", () => { + const input1: Record = { + cool: "yeah", + yeah: "cool", + blah: "5", + }; + const oneWay1 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: input1, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + let [base, params] = oneWay1.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + for (const [key, value] of Object.entries(input1)) { + expect(params).toContain(`${key}=${value}`); + } + + const input2 = new URLSearchParams(input1); + const oneWay2 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: input2, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + [base, params] = oneWay2.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + for (const [key, value] of Object.entries(input2)) { + expect(params).toContain(`${key}=${value}`); + } + + const oneWay3 = new OneWayWebSocket({ + apiRoute: dummyRoute, + websocketInit: (url, protocols) => { + const [socket] = createMockWebSocket(url, protocols); + return socket; + }, + searchParams: undefined, + location: { + protocol: "https:", + host: "www.blah.com", + }, + }); + [base, params] = oneWay3.url.split("?"); + expect(base).toBe("wss://www.blah.com/api/v2/blah"); + expect(params).toBe(undefined); + }); +}); diff --git a/site/src/utils/OneWayWebSocket.ts b/site/src/utils/OneWayWebSocket.ts new file mode 100644 index 0000000000000..94ed1f1efc868 --- /dev/null +++ b/site/src/utils/OneWayWebSocket.ts @@ -0,0 +1,198 @@ +/** + * @file A wrapper over WebSockets that (1) enforces one-way communication, and + * (2) supports automatically parsing JSON messages as they come in. + * + * This should ALWAYS be favored in favor of using Server-Sent Events and the + * built-in EventSource class for doing one-way communication. SSEs have a hard + * limitation on HTTP/1.1 and below where there is a maximum number of 6 ports + * that can ever be used for a domain (sometimes less depending on the browser). + * Not only is this limit shared with short-lived REST requests, but it also + * applies across tabs and windows. So if a user opens Coder in multiple tabs, + * there is a very real possibility that parts of the app will start to lock up + * without it being clear why. + * + * WebSockets do not have this limitation, even on HTTP/1.1 – all modern + * browsers implement at least some degree of multiplexing for them. + */ + +// Not bothering with trying to borrow methods from the base WebSocket type +// because it's already a mess of inheritance and generics, and we're going to +// have to add a few more +export type WebSocketEventType = "close" | "error" | "message" | "open"; + +export type OneWayMessageEvent = Readonly< + | { + sourceEvent: MessageEvent; + parsedMessage: TData; + parseError: undefined; + } + | { + sourceEvent: MessageEvent; + parsedMessage: undefined; + parseError: Error; + } +>; + +type OneWayEventPayloadMap = { + close: CloseEvent; + error: Event; + message: OneWayMessageEvent; + open: Event; +}; + +type WebSocketMessageCallback = (payload: MessageEvent) => void; + +type OneWayEventCallback = ( + payload: OneWayEventPayloadMap[TEvent], +) => void; + +interface OneWayWebSocketApi { + get url(): string; + + addEventListener: ( + eventType: TEvent, + callback: OneWayEventCallback, + ) => void; + + removeEventListener: ( + eventType: TEvent, + callback: OneWayEventCallback, + ) => void; + + close: (closeCode?: number, reason?: string) => void; +} + +type OneWayWebSocketInit = Readonly<{ + apiRoute: string; + serverProtocols?: string | string[]; + searchParams?: Record | URLSearchParams; + binaryType?: BinaryType; + websocketInit?: (url: string, protocols?: string | string[]) => WebSocket; + location?: Readonly<{ + protocol: string; + host: string; + }>; +}>; + +function defaultInit(url: string, protocols?: string | string[]): WebSocket { + return new WebSocket(url, protocols); +} + +export class OneWayWebSocket + implements OneWayWebSocketApi +{ + readonly #socket: WebSocket; + readonly #messageCallbackWrappers = new Map< + OneWayEventCallback, + WebSocketMessageCallback + >(); + + constructor(init: OneWayWebSocketInit) { + const { + apiRoute, + searchParams, + serverProtocols, + binaryType = "blob", + location = window.location, + websocketInit = defaultInit, + } = init; + + if (!apiRoute.startsWith("/api/v2/")) { + throw new Error(`API route '${apiRoute}' does not begin with a slash`); + } + + const formattedParams = + searchParams instanceof URLSearchParams + ? searchParams + : new URLSearchParams(searchParams); + const paramsString = formattedParams.toString(); + const paramsSuffix = paramsString ? `?${paramsString}` : ""; + const wsProtocol = location.protocol === "https:" ? "wss:" : "ws:"; + const url = `${wsProtocol}//${location.host}${apiRoute}${paramsSuffix}`; + + this.#socket = websocketInit(url, serverProtocols); + this.#socket.binaryType = binaryType; + } + + get url(): string { + return this.#socket.url; + } + + addEventListener( + event: TEvent, + callback: OneWayEventCallback, + ): void { + // Not happy about all the type assertions, but there are some nasty + // type contravariance issues if you try to resolve the function types + // properly. This is actually the lesser of two evils + const looseCallback = callback as OneWayEventCallback< + TData, + WebSocketEventType + >; + + if (this.#messageCallbackWrappers.has(looseCallback)) { + return; + } + if (event !== "message") { + this.#socket.addEventListener(event, looseCallback); + return; + } + + const wrapped = (event: MessageEvent): void => { + const messageCallback = looseCallback as OneWayEventCallback< + TData, + "message" + >; + + try { + const message = JSON.parse(event.data) as TData; + messageCallback({ + sourceEvent: event, + parseError: undefined, + parsedMessage: message, + }); + } catch (err) { + messageCallback({ + sourceEvent: event, + parseError: err as Error, + parsedMessage: undefined, + }); + } + }; + + this.#socket.addEventListener(event as "message", wrapped); + this.#messageCallbackWrappers.set(looseCallback, wrapped); + } + + removeEventListener( + event: TEvent, + callback: OneWayEventCallback, + ): void { + const looseCallback = callback as OneWayEventCallback< + TData, + WebSocketEventType + >; + + if (event !== "message") { + this.#socket.removeEventListener(event, looseCallback); + return; + } + if (!this.#messageCallbackWrappers.has(looseCallback)) { + return; + } + + const wrapper = this.#messageCallbackWrappers.get(looseCallback); + if (wrapper === undefined) { + throw new Error( + `Cannot unregister callback for event ${event}. This is likely an issue with the browser itself.`, + ); + } + + this.#socket.removeEventListener(event as "message", wrapper); + this.#messageCallbackWrappers.delete(looseCallback); + } + + close(closeCode?: number, reason?: string): void { + this.#socket.close(closeCode, reason); + } +}