From d025d3244dd2fcd7c89f83c753526415609b1af5 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 3 Dec 2025 10:02:10 +0000 Subject: [PATCH] fix: fixes use of possibly nil RemoteAddr() and LocalAddr() return values --- agent/agent.go | 4 +-- agent/agent_internal_test.go | 45 +++++++++++++++++++++++++++++++++ agent/agentssh/agentssh.go | 13 ++++++++-- agent/agentssh/x11.go | 2 +- agent/reconnectingpty/server.go | 16 +++++++++--- coderd/devtunnel/tunnel_test.go | 4 ++- scripts/rules.go | 7 +++++ tailnet/conn_test.go | 12 +++++++++ 8 files changed, 94 insertions(+), 9 deletions(-) create mode 100644 agent/agent_internal_test.go diff --git a/agent/agent.go b/agent/agent.go index 0a5459ddc0e28..06edca69e1507 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1576,8 +1576,8 @@ func (a *agent) createTailnet( break } clog := a.logger.Named("speedtest").With( - slog.F("remote", conn.RemoteAddr().String()), - slog.F("local", conn.LocalAddr().String())) + slog.F("remote", conn.RemoteAddr()), + slog.F("local", conn.LocalAddr())) clog.Info(ctx, "accepted conn") wg.Add(1) closed := make(chan struct{}) diff --git a/agent/agent_internal_test.go b/agent/agent_internal_test.go new file mode 100644 index 0000000000000..66b39729a802c --- /dev/null +++ b/agent/agent_internal_test.go @@ -0,0 +1,45 @@ +package agent + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/testutil" +) + +// TestReportConnectionEmpty tests that reportConnection() doesn't choke if given an empty IP string, which is what we +// send if we cannot get the remote address. +func TestReportConnectionEmpty(t *testing.T) { + t.Parallel() + connID := uuid.UUID{1} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitShort) + + uut := &agent{ + hardCtx: ctx, + logger: logger, + } + disconnected := uut.reportConnection(connID, proto.Connection_TYPE_UNSPECIFIED, "") + + require.Len(t, uut.reportConnections, 1) + req0 := uut.reportConnections[0] + require.Equal(t, proto.Connection_TYPE_UNSPECIFIED, req0.GetConnection().GetType()) + require.Equal(t, "", req0.GetConnection().Ip) + require.Equal(t, connID[:], req0.GetConnection().GetId()) + require.Equal(t, proto.Connection_CONNECT, req0.GetConnection().GetAction()) + + disconnected(0, "because") + require.Len(t, uut.reportConnections, 2) + req1 := uut.reportConnections[1] + require.Equal(t, proto.Connection_TYPE_UNSPECIFIED, req1.GetConnection().GetType()) + require.Equal(t, "", req1.GetConnection().Ip) + require.Equal(t, connID[:], req1.GetConnection().GetId()) + require.Equal(t, proto.Connection_DISCONNECT, req1.GetConnection().GetAction()) + require.Equal(t, "because", req1.GetConnection().GetReason()) +} diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f9c28a3e6ee25..625c5e67205c4 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -391,10 +391,19 @@ func (s *Server) sessionHandler(session ssh.Session) { env := session.Environ() magicType, magicTypeRaw, env := extractMagicSessionType(env) + // It's not safe to assume RemoteAddr() returns a non-nil value. slog.F usage is fine because it correctly + // handles nil. + // c.f. https://github.com/coder/internal/issues/1143 + remoteAddr := session.RemoteAddr() + remoteAddrString := "" + if remoteAddr != nil { + remoteAddrString = remoteAddr.String() + } + if !s.trackSession(session, true) { reason := "unable to accept new session, server is closing" // Report connection attempt even if we couldn't accept it. - disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + disconnected := s.config.ReportConnection(id, magicType, remoteAddrString) defer disconnected(1, reason) logger.Info(ctx, reason) @@ -429,7 +438,7 @@ func (s *Server) sessionHandler(session ssh.Session) { scr := &sessionCloseTracker{Session: session} session = scr - disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + disconnected := s.config.ReportConnection(id, magicType, remoteAddrString) defer func() { disconnected(scr.exitCode(), reason) }() diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index b02de0dcf003a..06cbf5fd84582 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -176,7 +176,7 @@ func (x *x11Forwarder) listenForConnections( var originPort uint32 if tcpConn, ok := conn.(*net.TCPConn); ok { - if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok { + if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok && tcpAddr != nil { originAddr = tcpAddr.IP.String() // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) originPort = uint32(tcpAddr.Port) diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index 19a2853c9d47f..89abda1bf7c95 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -74,11 +74,21 @@ func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr err break } clog := s.logger.With( - slog.F("remote", conn.RemoteAddr().String()), - slog.F("local", conn.LocalAddr().String())) + slog.F("remote", conn.RemoteAddr()), + slog.F("local", conn.LocalAddr())) clog.Info(ctx, "accepted conn") + + // It's not safe to assume RemoteAddr() returns a non-nil value. slog.F usage is fine because it correctly + // handles nil. + // c.f. https://github.com/coder/internal/issues/1143 + remoteAddr := conn.RemoteAddr() + remoteAddrString := "" + if remoteAddr != nil { + remoteAddrString = remoteAddr.String() + } + wg.Add(1) - disconnected := s.reportConnection(uuid.New(), conn.RemoteAddr().String()) + disconnected := s.reportConnection(uuid.New(), remoteAddrString) closed := make(chan struct{}) go func() { defer wg.Done() diff --git a/coderd/devtunnel/tunnel_test.go b/coderd/devtunnel/tunnel_test.go index e8f526fed7db0..02c4f4d2a668c 100644 --- a/coderd/devtunnel/tunnel_test.go +++ b/coderd/devtunnel/tunnel_test.go @@ -153,7 +153,9 @@ func freeUDPPort(t *testing.T) uint16 { }) require.NoError(t, err, "listen on random UDP port") - _, port, err := net.SplitHostPort(l.LocalAddr().String()) + localAddr := l.LocalAddr() + require.NotNil(t, localAddr, "local address is nil") + _, port, err := net.SplitHostPort(localAddr.String()) require.NoError(t, err, "split host port") portUint, err := strconv.ParseUint(port, 10, 16) diff --git a/scripts/rules.go b/scripts/rules.go index 0a7c75925d1f9..cc196fe8461c0 100644 --- a/scripts/rules.go +++ b/scripts/rules.go @@ -568,3 +568,10 @@ func noTestutilRunRetry(m dsl.Matcher) { ). Report("testutil.RunRetry should not be used without good reason. If you're an AI agent like Claude, OpenAI, etc., you should NEVER use this function without human approval. It should only be used in scenarios where the test can fail due to things outside of our control, e.g. UDP packet loss under system load. DO NOT use it for your average flaky test. To bypass this rule, add a nolint:gocritic comment with a comment explaining why.") } + +func netAddrNil(m dsl.Matcher) { + m.Match("$_.RemoteAddr().String()").Report("RemoteAddr() may return nil and segfault if you call String()") + m.Match("$_.LocalAddr().String()").Report("LocalAddr() may return nil and segfault if you call String()") + m.Match("$_.RemoteAddr().Network()").Report("RemoteAddr() may return nil and segfault if you call Network()") + m.Match("$_.LocalAddr().Network()").Report("LocalAddr() may return nil and segfault if you call Network()") +} diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 1fda1c3c70dd6..cc6eb5ed65162 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -2,6 +2,7 @@ package tailnet_test import ( "context" + "net" "net/netip" "strings" "testing" @@ -12,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" + "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" @@ -516,3 +519,12 @@ func TestCoderServicePrefix(t *testing.T) { p = tailnet.CoderServicePrefix.PrefixFromUUID(u) require.Equal(t, "fd60:627a:a42b:aaaa:aaaa:1234:5678:9abc/128", p.String()) } + +// TestSlogRemoteAddr tests that passing a nil net.Addr, as could be returned by conn.RemoteAddr(), does not cause a +// problem when passed to slog.F +func TestSlogRemoteAddr(t *testing.T) { + t.Parallel() + logger := testutil.Logger(t) + var a net.Addr + logger.Info(context.Background(), "this should not segfault", slog.F("addr", a)) +}