Skip to content

Commit ff1df69

Browse files
committed
fix: mock Agent querying OS for listening ports in tests
1 parent 48b8e22 commit ff1df69

File tree

6 files changed

+219
-291
lines changed

6 files changed

+219
-291
lines changed

agent/agent.go

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"hash/fnv"
1010
"io"
11+
"maps"
1112
"net"
1213
"net/http"
1314
"net/netip"
@@ -70,16 +71,21 @@ const (
7071
)
7172

7273
type Options struct {
73-
Filesystem afero.Fs
74-
LogDir string
75-
TempDir string
76-
ScriptDataDir string
77-
Client Client
78-
ReconnectingPTYTimeout time.Duration
79-
EnvironmentVariables map[string]string
80-
Logger slog.Logger
81-
IgnorePorts map[int]string
82-
PortCacheDuration time.Duration
74+
Filesystem afero.Fs
75+
LogDir string
76+
TempDir string
77+
ScriptDataDir string
78+
Client Client
79+
ReconnectingPTYTimeout time.Duration
80+
EnvironmentVariables map[string]string
81+
Logger slog.Logger
82+
// IgnorePorts tells the api handler which ports to ignore when
83+
// listing all listening ports. This is helpful to hide ports that
84+
// are used by the agent, that the user does not care about.
85+
IgnorePorts map[int]string
86+
// ListeningPortsGetter is used to get the list of listening ports. Only
87+
// tests should set this. If unset, a default that queries the OS will be used.
88+
ListeningPortsGetter ListeningPortsGetter
8389
SSHMaxTimeout time.Duration
8490
TailnetListenPort uint16
8591
Subsystems []codersdk.AgentSubsystem
@@ -137,9 +143,7 @@ func New(options Options) Agent {
137143
if options.ServiceBannerRefreshInterval == 0 {
138144
options.ServiceBannerRefreshInterval = 2 * time.Minute
139145
}
140-
if options.PortCacheDuration == 0 {
141-
options.PortCacheDuration = 1 * time.Second
142-
}
146+
143147
if options.Clock == nil {
144148
options.Clock = quartz.NewReal()
145149
}
@@ -153,30 +157,38 @@ func New(options Options) Agent {
153157
options.Execer = agentexec.DefaultExecer
154158
}
155159

160+
if options.ListeningPortsGetter == nil {
161+
options.ListeningPortsGetter = &osListeningPortsGetter{
162+
cacheDuration: 1 * time.Second,
163+
}
164+
}
165+
156166
hardCtx, hardCancel := context.WithCancel(context.Background())
157167
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
158168
a := &agent{
159-
clock: options.Clock,
160-
tailnetListenPort: options.TailnetListenPort,
161-
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
162-
logger: options.Logger,
163-
gracefulCtx: gracefulCtx,
164-
gracefulCancel: gracefulCancel,
165-
hardCtx: hardCtx,
166-
hardCancel: hardCancel,
167-
coordDisconnected: make(chan struct{}),
168-
environmentVariables: options.EnvironmentVariables,
169-
client: options.Client,
170-
filesystem: options.Filesystem,
171-
logDir: options.LogDir,
172-
tempDir: options.TempDir,
173-
scriptDataDir: options.ScriptDataDir,
174-
lifecycleUpdate: make(chan struct{}, 1),
175-
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
176-
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
177-
reportConnectionsUpdate: make(chan struct{}, 1),
178-
ignorePorts: options.IgnorePorts,
179-
portCacheDuration: options.PortCacheDuration,
169+
clock: options.Clock,
170+
tailnetListenPort: options.TailnetListenPort,
171+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
172+
logger: options.Logger,
173+
gracefulCtx: gracefulCtx,
174+
gracefulCancel: gracefulCancel,
175+
hardCtx: hardCtx,
176+
hardCancel: hardCancel,
177+
coordDisconnected: make(chan struct{}),
178+
environmentVariables: options.EnvironmentVariables,
179+
client: options.Client,
180+
filesystem: options.Filesystem,
181+
logDir: options.LogDir,
182+
tempDir: options.TempDir,
183+
scriptDataDir: options.ScriptDataDir,
184+
lifecycleUpdate: make(chan struct{}, 1),
185+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
186+
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
187+
reportConnectionsUpdate: make(chan struct{}, 1),
188+
listeningPortsHandler: listeningPortsHandler{
189+
getter: options.ListeningPortsGetter,
190+
ignorePorts: maps.Clone(options.IgnorePorts),
191+
},
180192
reportMetadataInterval: options.ReportMetadataInterval,
181193
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
182194
sshMaxTimeout: options.SSHMaxTimeout,
@@ -202,20 +214,16 @@ func New(options Options) Agent {
202214
}
203215

204216
type agent struct {
205-
clock quartz.Clock
206-
logger slog.Logger
207-
client Client
208-
tailnetListenPort uint16
209-
filesystem afero.Fs
210-
logDir string
211-
tempDir string
212-
scriptDataDir string
213-
// ignorePorts tells the api handler which ports to ignore when
214-
// listing all listening ports. This is helpful to hide ports that
215-
// are used by the agent, that the user does not care about.
216-
ignorePorts map[int]string
217-
portCacheDuration time.Duration
218-
subsystems []codersdk.AgentSubsystem
217+
clock quartz.Clock
218+
logger slog.Logger
219+
client Client
220+
tailnetListenPort uint16
221+
filesystem afero.Fs
222+
logDir string
223+
tempDir string
224+
scriptDataDir string
225+
listeningPortsHandler listeningPortsHandler
226+
subsystems []codersdk.AgentSubsystem
219227

220228
reconnectingPTYTimeout time.Duration
221229
reconnectingPTYServer *reconnectingpty.Server

agent/api.go

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ package agent
22

33
import (
44
"net/http"
5-
"sync"
6-
"time"
75

86
"github.com/go-chi/chi/v5"
97
"github.com/google/uuid"
108

119
"github.com/coder/coder/v2/coderd/httpapi"
1210
"github.com/coder/coder/v2/codersdk"
11+
"github.com/coder/coder/v2/codersdk/workspacesdk"
1312
)
1413

1514
func (a *agent) apiHandler() http.Handler {
@@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler {
2019
})
2120
})
2221

23-
// Make a copy to ensure the map is not modified after the handler is
24-
// created.
25-
cpy := make(map[int]string)
26-
for k, b := range a.ignorePorts {
27-
cpy[k] = b
28-
}
29-
30-
cacheDuration := 1 * time.Second
31-
if a.portCacheDuration > 0 {
32-
cacheDuration = a.portCacheDuration
33-
}
34-
35-
lp := &listeningPortsHandler{
36-
ignorePorts: cpy,
37-
cacheDuration: cacheDuration,
38-
}
39-
4022
if a.devcontainers {
4123
r.Mount("/api/v0/containers", a.containerAPI.Routes())
4224
} else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil {
@@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler {
5739

5840
promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger)
5941

60-
r.Get("/api/v0/listening-ports", lp.handler)
42+
r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler)
6143
r.Get("/api/v0/netcheck", a.HandleNetcheck)
6244
r.Post("/api/v0/list-directory", a.HandleLS)
6345
r.Get("/api/v0/read-file", a.HandleReadFile)
@@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler {
7254
return r
7355
}
7456

75-
type listeningPortsHandler struct {
76-
ignorePorts map[int]string
77-
cacheDuration time.Duration
57+
type ListeningPortsGetter interface {
58+
GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error)
59+
}
7860

79-
//nolint: unused // used on some but not all platforms
80-
mut sync.Mutex
81-
//nolint: unused // used on some but not all platforms
82-
ports []codersdk.WorkspaceAgentListeningPort
83-
//nolint: unused // used on some but not all platforms
84-
mtime time.Time
61+
type listeningPortsHandler struct {
62+
// In production code, this is set to an osListeningPortsGetter, but it can be overridden for
63+
// testing.
64+
getter ListeningPortsGetter
65+
ignorePorts map[int]string
8566
}
8667

8768
// handler returns a list of listening ports. This is tested by coderd's
8869
// TestWorkspaceAgentListeningPorts test.
8970
func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) {
90-
ports, err := lp.getListeningPorts()
71+
ports, err := lp.getter.GetListeningPorts()
9172
if err != nil {
9273
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
9374
Message: "Could not scan for listening ports.",
@@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request
9677
return
9778
}
9879

80+
filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports))
81+
for _, port := range ports {
82+
if port.Port < workspacesdk.AgentMinimumListeningPort {
83+
continue
84+
}
85+
86+
// Ignore ports that we've been told to ignore.
87+
if _, ok := lp.ignorePorts[int(port.Port)]; ok {
88+
continue
89+
}
90+
filteredPorts = append(filteredPorts, port)
91+
}
92+
9993
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{
100-
Ports: ports,
94+
Ports: filteredPorts,
10195
})
10296
}

agent/ports_supported.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@
33
package agent
44

55
import (
6+
"sync"
67
"time"
78

89
"github.com/cakturk/go-netstat/netstat"
910
"golang.org/x/xerrors"
1011

1112
"github.com/coder/coder/v2/codersdk"
12-
"github.com/coder/coder/v2/codersdk/workspacesdk"
1313
)
1414

15-
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
15+
type osListeningPortsGetter struct {
16+
cacheDuration time.Duration
17+
mut sync.Mutex
18+
ports []codersdk.WorkspaceAgentListeningPort
19+
mtime time.Time
20+
}
21+
22+
func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
1623
lp.mut.Lock()
1724
defer lp.mut.Unlock()
1825

@@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
3340
seen := make(map[uint16]struct{}, len(tabs))
3441
ports := []codersdk.WorkspaceAgentListeningPort{}
3542
for _, tab := range tabs {
36-
if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort {
37-
continue
38-
}
39-
40-
// Ignore ports that we've been told to ignore.
41-
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
43+
if tab.LocalAddr == nil {
4244
continue
4345
}
4446

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//go:build linux || (windows && amd64)
2+
3+
package agent
4+
5+
import (
6+
"net"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestOSListeningPortsGetter(t *testing.T) {
14+
t.Parallel()
15+
16+
uut := &osListeningPortsGetter{
17+
cacheDuration: 1 * time.Hour,
18+
}
19+
20+
l, err := net.Listen("tcp", "localhost:0")
21+
require.NoError(t, err)
22+
defer l.Close()
23+
24+
ports, err := uut.GetListeningPorts()
25+
require.NoError(t, err)
26+
found := false
27+
for _, port := range ports {
28+
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
29+
if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) {
30+
found = true
31+
break
32+
}
33+
}
34+
require.True(t, found)
35+
36+
// check that we cache the ports
37+
err = l.Close()
38+
require.NoError(t, err)
39+
portsNew, err := uut.GetListeningPorts()
40+
require.NoError(t, err)
41+
require.Equal(t, ports, portsNew)
42+
43+
// note that it's unsafe to try to assert that a port does not exist in the response
44+
// because the OS may reallocate the port very quickly.
45+
}

agent/ports_unsupported.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22

33
package agent
44

5-
import "github.com/coder/coder/v2/codersdk"
5+
import (
6+
"time"
67

7-
func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
8+
"github.com/coder/coder/v2/codersdk"
9+
)
10+
11+
type osListeningPortsGetter struct {
12+
cacheDuration time.Duration
13+
}
14+
15+
func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
816
// Can't scan for ports on non-linux or non-windows_amd64 systems at the
917
// moment. The UI will not show any "no ports found" message to the user, so
1018
// the user won't suspect a thing.

0 commit comments

Comments
 (0)