Skip to content

Commit afd4043

Browse files
authored
fix: mock Agent querying OS for listening ports in tests (#20842)
fixes coder/internal#1123 We want to tests that ports are not included after they are no longer used, but this isn't safe on the real OS networking stack because there is no way to guarantee a port _won't_ be used. Instead, we introduce an interface and fake implementation for testing. On order to leave the filtering logic in the test path, this PR also does some refactoring. Caching logic is left in the real OS querying implementation and a new test case is added for it in this PR.
1 parent 823009d commit afd4043

File tree

6 files changed

+219
-291
lines changed

6 files changed

+219
-291
lines changed

agent/agent.go

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"hash/fnv"
1010
"io"
11+
"maps"
1112
"net"
1213
"net/http"
1314
"net/netip"
@@ -70,16 +71,21 @@ const (
7071
)
7172

7273
type Options struct {
73-
Filesystem afero.Fs
74-
LogDir string
75-
TempDir string
76-
ScriptDataDir string
77-
Client Client
78-
ReconnectingPTYTimeout time.Duration
79-
EnvironmentVariables map[string]string
80-
Logger slog.Logger
81-
IgnorePorts map[int]string
82-
PortCacheDuration time.Duration
74+
Filesystem afero.Fs
75+
LogDir string
76+
TempDir string
77+
ScriptDataDir string
78+
Client Client
79+
ReconnectingPTYTimeout time.Duration
80+
EnvironmentVariables map[string]string
81+
Logger slog.Logger
82+
// IgnorePorts tells the api handler which ports to ignore when
83+
// listing all listening ports. This is helpful to hide ports that
84+
// are used by the agent, that the user does not care about.
85+
IgnorePorts map[int]string
86+
// ListeningPortsGetter is used to get the list of listening ports. Only
87+
// tests should set this. If unset, a default that queries the OS will be used.
88+
ListeningPortsGetter ListeningPortsGetter
8389
SSHMaxTimeout time.Duration
8490
TailnetListenPort uint16
8591
Subsystems []codersdk.AgentSubsystem
@@ -137,9 +143,7 @@ func New(options Options) Agent {
137143
if options.ServiceBannerRefreshInterval == 0 {
138144
options.ServiceBannerRefreshInterval = 2 * time.Minute
139145
}
140-
if options.PortCacheDuration == 0 {
141-
options.PortCacheDuration = 1 * time.Second
142-
}
146+
143147
if options.Clock == nil {
144148
options.Clock = quartz.NewReal()
145149
}
@@ -153,30 +157,38 @@ func New(options Options) Agent {
153157
options.Execer = agentexec.DefaultExecer
154158
}
155159

160+
if options.ListeningPortsGetter == nil {
161+
options.ListeningPortsGetter = &osListeningPortsGetter{
162+
cacheDuration: 1 * time.Second,
163+
}
164+
}
165+
156166
hardCtx, hardCancel := context.WithCancel(context.Background())
157167
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
158168
a := &agent{
159-
clock: options.Clock,
160-
tailnetListenPort: options.TailnetListenPort,
161-
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
162-
logger: options.Logger,
163-
gracefulCtx: gracefulCtx,
164-
gracefulCancel: gracefulCancel,
165-
hardCtx: hardCtx,
166-
hardCancel: hardCancel,
167-
coordDisconnected: make(chan struct{}),
168-
environmentVariables: options.EnvironmentVariables,
169-
client: options.Client,
170-
filesystem: options.Filesystem,
171-
logDir: options.LogDir,
172-
tempDir: options.TempDir,
173-
scriptDataDir: options.ScriptDataDir,
174-
lifecycleUpdate: make(chan struct{}, 1),
175-
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
176-
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
177-
reportConnectionsUpdate: make(chan struct{}, 1),
178-
ignorePorts: options.IgnorePorts,
179-
portCacheDuration: options.PortCacheDuration,
169+
clock: options.Clock,
170+
tailnetListenPort: options.TailnetListenPort,
171+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
172+
logger: options.Logger,
173+
gracefulCtx: gracefulCtx,
174+
gracefulCancel: gracefulCancel,
175+
hardCtx: hardCtx,
176+
hardCancel: hardCancel,
177+
coordDisconnected: make(chan struct{}),
178+
environmentVariables: options.EnvironmentVariables,
179+
client: options.Client,
180+
filesystem: options.Filesystem,
181+
logDir: options.LogDir,
182+
tempDir: options.TempDir,
183+
scriptDataDir: options.ScriptDataDir,
184+
lifecycleUpdate: make(chan struct{}, 1),
185+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
186+
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
187+
reportConnectionsUpdate: make(chan struct{}, 1),
188+
listeningPortsHandler: listeningPortsHandler{
189+
getter: options.ListeningPortsGetter,
190+
ignorePorts: maps.Clone(options.IgnorePorts),
191+
},
180192
reportMetadataInterval: options.ReportMetadataInterval,
181193
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
182194
sshMaxTimeout: options.SSHMaxTimeout,
@@ -202,20 +214,16 @@ func New(options Options) Agent {
202214
}
203215

204216
type agent struct {
205-
clock quartz.Clock
206-
logger slog.Logger
207-
client Client
208-
tailnetListenPort uint16
209-
filesystem afero.Fs
210-
logDir string
211-
tempDir string
212-
scriptDataDir string
213-
// ignorePorts tells the api handler which ports to ignore when
214-
// listing all listening ports. This is helpful to hide ports that
215-
// are used by the agent, that the user does not care about.
216-
ignorePorts map[int]string
217-
portCacheDuration time.Duration
218-
subsystems []codersdk.AgentSubsystem
217+
clock quartz.Clock
218+
logger slog.Logger
219+
client Client
220+
tailnetListenPort uint16
221+
filesystem afero.Fs
222+
logDir string
223+
tempDir string
224+
scriptDataDir string
225+
listeningPortsHandler listeningPortsHandler
226+
subsystems []codersdk.AgentSubsystem
219227

220228
reconnectingPTYTimeout time.Duration
221229
reconnectingPTYServer *reconnectingpty.Server

agent/api.go

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ package agent
22

33
import (
44
"net/http"
5-
"sync"
6-
"time"
75

86
"github.com/go-chi/chi/v5"
97
"github.com/google/uuid"
108

119
"github.com/coder/coder/v2/coderd/httpapi"
1210
"github.com/coder/coder/v2/codersdk"
11+
"github.com/coder/coder/v2/codersdk/workspacesdk"
1312
)
1413

1514
func (a *agent) apiHandler() http.Handler {
@@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler {
2019
})
2120
})
2221

23-
// Make a copy to ensure the map is not modified after the handler is
24-
// created.
25-
cpy := make(map[int]string)
26-
for k, b := range a.ignorePorts {
27-
cpy[k] = b
28-
}
29-
30-
cacheDuration := 1 * time.Second
31-
if a.portCacheDuration > 0 {
32-
cacheDuration = a.portCacheDuration
33-
}
34-
35-
lp := &listeningPortsHandler{
36-
ignorePorts: cpy,
37-
cacheDuration: cacheDuration,
38-
}
39-
4022
if a.devcontainers {
4123
r.Mount("/api/v0/containers", a.containerAPI.Routes())
4224
} else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil {
@@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler {
5739

5840
promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger)
5941

60-
r.Get("/api/v0/listening-ports", lp.handler)
42+
r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler)
6143
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6244
r.Post("/api/v0/list-directory", a.HandleLS)
6345
r.Get("/api/v0/read-file", a.HandleReadFile)
@@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler {
7254
return r
7355
}
7456

75-
type listeningPortsHandler struct {
76-
ignorePorts map[int]string
77-
cacheDuration time.Duration
57+
type ListeningPortsGetter interface {
58+
GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error)
59+
}
7860

79-
//nolint: unused // used on some but not all platforms
80-
mut sync.Mutex
81-
//nolint: unused // used on some but not all platforms
82-
ports []codersdk.WorkspaceAgentListeningPort
83-
//nolint: unused // used on some but not all platforms
84-
mtime time.Time
61+
type listeningPortsHandler struct {
62+
// In production code, this is set to an osListeningPortsGetter, but it can be overridden for
63+
// testing.
64+
getter ListeningPortsGetter
65+
ignorePorts map[int]string
8566
}
8667

8768
// handler returns a list of listening ports. This is tested by coderd's
8869
// TestWorkspaceAgentListeningPorts test.
8970
func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) {
90-
ports, err := lp.getListeningPorts()
71+
ports, err := lp.getter.GetListeningPorts()
9172
if err != nil {
9273
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
9374
Message: "Could not scan for listening ports.",
@@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request
9677
return
9778
}
9879

80+
filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports))
81+
for _, port := range ports {
82+
if port.Port < workspacesdk.AgentMinimumListeningPort {
83+
continue
84+
}
85+
86+
// Ignore ports that we've been told to ignore.
87+
if _, ok := lp.ignorePorts[int(port.Port)]; ok {
88+
continue
89+
}
90+
filteredPorts = append(filteredPorts, port)
91+
}
92+
9993
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{
100-
Ports: ports,
94+
Ports: filteredPorts,
10195
})
10296
}

agent/ports_supported.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@
33
package agent
44

55
import (
6+
"sync"
67
"time"
78

89
"github.com/cakturk/go-netstat/netstat"
910
"golang.org/x/xerrors"
1011

1112
"github.com/coder/coder/v2/codersdk"
12-
"github.com/coder/coder/v2/codersdk/workspacesdk"
1313
)
1414

15-
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
15+
type osListeningPortsGetter struct {
16+
cacheDuration time.Duration
17+
mut sync.Mutex
18+
ports []codersdk.WorkspaceAgentListeningPort
19+
mtime time.Time
20+
}
21+
22+
func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
1623
lp.mut.Lock()
1724
defer lp.mut.Unlock()
1825

@@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
3340
seen := make(map[uint16]struct{}, len(tabs))
3441
ports := []codersdk.WorkspaceAgentListeningPort{}
3542
for _, tab := range tabs {
36-
if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort {
37-
continue
38-
}
39-
40-
// Ignore ports that we've been told to ignore.
41-
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
43+
if tab.LocalAddr == nil {
4244
continue
4345
}
4446

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//go:build linux || (windows && amd64)
2+
3+
package agent
4+
5+
import (
6+
"net"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestOSListeningPortsGetter(t *testing.T) {
14+
t.Parallel()
15+
16+
uut := &osListeningPortsGetter{
17+
cacheDuration: 1 * time.Hour,
18+
}
19+
20+
l, err := net.Listen("tcp", "localhost:0")
21+
require.NoError(t, err)
22+
defer l.Close()
23+
24+
ports, err := uut.GetListeningPorts()
25+
require.NoError(t, err)
26+
found := false
27+
for _, port := range ports {
28+
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
29+
if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) {
30+
found = true
31+
break
32+
}
33+
}
34+
require.True(t, found)
35+
36+
// check that we cache the ports
37+
err = l.Close()
38+
require.NoError(t, err)
39+
portsNew, err := uut.GetListeningPorts()
40+
require.NoError(t, err)
41+
require.Equal(t, ports, portsNew)
42+
43+
// note that it's unsafe to try to assert that a port does not exist in the response
44+
// because the OS may reallocate the port very quickly.
45+
}

agent/ports_unsupported.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22

33
package agent
44

5-
import "github.com/coder/coder/v2/codersdk"
5+
import (
6+
"time"
67

7-
func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
8+
"github.com/coder/coder/v2/codersdk"
9+
)
10+
11+
type osListeningPortsGetter struct {
12+
cacheDuration time.Duration
13+
}
14+
15+
func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
816
// Can't scan for ports on non-linux or non-windows_amd64 systems at the
917
// moment. The UI will not show any "no ports found" message to the user, so
1018
// the user won't suspect a thing.

0 commit comments

Comments
 (0)