Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 56 additions & 48 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"hash/fnv"
"io"
"maps"
"net"
"net/http"
"net/netip"
Expand Down Expand Up @@ -70,16 +71,21 @@ const (
)

type Options struct {
Filesystem afero.Fs
LogDir string
TempDir string
ScriptDataDir string
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
IgnorePorts map[int]string
PortCacheDuration time.Duration
Filesystem afero.Fs
LogDir string
TempDir string
ScriptDataDir string
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
// IgnorePorts tells the api handler which ports to ignore when
// listing all listening ports. This is helpful to hide ports that
// are used by the agent, that the user does not care about.
IgnorePorts map[int]string
// ListeningPortsGetter is used to get the list of listening ports. Only
// tests should set this. If unset, a default that queries the OS will be used.
ListeningPortsGetter ListeningPortsGetter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: These could be simplified a bit in naming by dropping "Getter". It makes these types somewhat verbose and limits future expansion of the type.

Hypothetical expansion: addition of ClearCache method.

A funny but legit alternative would be: type ListeningPortser interface { ListeningPorts() ... }.

SSHMaxTimeout time.Duration
TailnetListenPort uint16
Subsystems []codersdk.AgentSubsystem
Expand Down Expand Up @@ -137,9 +143,7 @@ func New(options Options) Agent {
if options.ServiceBannerRefreshInterval == 0 {
options.ServiceBannerRefreshInterval = 2 * time.Minute
}
if options.PortCacheDuration == 0 {
options.PortCacheDuration = 1 * time.Second
}

if options.Clock == nil {
options.Clock = quartz.NewReal()
}
Expand All @@ -153,30 +157,38 @@ func New(options Options) Agent {
options.Execer = agentexec.DefaultExecer
}

if options.ListeningPortsGetter == nil {
options.ListeningPortsGetter = &osListeningPortsGetter{
cacheDuration: 1 * time.Second,
}
}

hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{
clock: options.Clock,
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
clock: options.Clock,
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
listeningPortsHandler: listeningPortsHandler{
getter: options.ListeningPortsGetter,
ignorePorts: maps.Clone(options.IgnorePorts),
},
reportMetadataInterval: options.ReportMetadataInterval,
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
Expand All @@ -202,20 +214,16 @@ func New(options Options) Agent {
}

type agent struct {
clock quartz.Clock
logger slog.Logger
client Client
tailnetListenPort uint16
filesystem afero.Fs
logDir string
tempDir string
scriptDataDir string
// ignorePorts tells the api handler which ports to ignore when
// listing all listening ports. This is helpful to hide ports that
// are used by the agent, that the user does not care about.
ignorePorts map[int]string
portCacheDuration time.Duration
subsystems []codersdk.AgentSubsystem
clock quartz.Clock
logger slog.Logger
client Client
tailnetListenPort uint16
filesystem afero.Fs
logDir string
tempDir string
scriptDataDir string
listeningPortsHandler listeningPortsHandler
subsystems []codersdk.AgentSubsystem

reconnectingPTYTimeout time.Duration
reconnectingPTYServer *reconnectingpty.Server
Expand Down
56 changes: 25 additions & 31 deletions agent/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ package agent

import (
"net/http"
"sync"
"time"

"github.com/go-chi/chi/v5"
"github.com/google/uuid"

"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)

func (a *agent) apiHandler() http.Handler {
Expand All @@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler {
})
})

// Make a copy to ensure the map is not modified after the handler is
// created.
cpy := make(map[int]string)
for k, b := range a.ignorePorts {
cpy[k] = b
}

cacheDuration := 1 * time.Second
if a.portCacheDuration > 0 {
cacheDuration = a.portCacheDuration
}

lp := &listeningPortsHandler{
ignorePorts: cpy,
cacheDuration: cacheDuration,
}

if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
} else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil {
Expand All @@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler {

promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger)

r.Get("/api/v0/listening-ports", lp.handler)
r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler)
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Post("/api/v0/list-directory", a.HandleLS)
r.Get("/api/v0/read-file", a.HandleReadFile)
Expand All @@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler {
return r
}

type listeningPortsHandler struct {
ignorePorts map[int]string
cacheDuration time.Duration
type ListeningPortsGetter interface {
GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error)
}

//nolint: unused // used on some but not all platforms
mut sync.Mutex
//nolint: unused // used on some but not all platforms
ports []codersdk.WorkspaceAgentListeningPort
//nolint: unused // used on some but not all platforms
mtime time.Time
type listeningPortsHandler struct {
// In production code, this is set to an osListeningPortsGetter, but it can be overridden for
// testing.
getter ListeningPortsGetter
ignorePorts map[int]string
}

// handler returns a list of listening ports. This is tested by coderd's
// TestWorkspaceAgentListeningPorts test.
func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) {
ports, err := lp.getListeningPorts()
ports, err := lp.getter.GetListeningPorts()
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Could not scan for listening ports.",
Expand All @@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request
return
}

filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports))
for _, port := range ports {
if port.Port < workspacesdk.AgentMinimumListeningPort {
continue
}

// Ignore ports that we've been told to ignore.
if _, ok := lp.ignorePorts[int(port.Port)]; ok {
continue
}
filteredPorts = append(filteredPorts, port)
}

httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{
Ports: ports,
Ports: filteredPorts,
})
}
18 changes: 10 additions & 8 deletions agent/ports_supported.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@
package agent

import (
"sync"
"time"

"github.com/cakturk/go-netstat/netstat"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)

func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
type osListeningPortsGetter struct {
cacheDuration time.Duration
mut sync.Mutex
ports []codersdk.WorkspaceAgentListeningPort
mtime time.Time
}

func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
lp.mut.Lock()
defer lp.mut.Unlock()

Expand All @@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
seen := make(map[uint16]struct{}, len(tabs))
ports := []codersdk.WorkspaceAgentListeningPort{}
for _, tab := range tabs {
if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort {
continue
}

// Ignore ports that we've been told to ignore.
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
if tab.LocalAddr == nil {
continue
}

Expand Down
45 changes: 45 additions & 0 deletions agent/ports_supported_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//go:build linux || (windows && amd64)

package agent

import (
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestOSListeningPortsGetter(t *testing.T) {
t.Parallel()

uut := &osListeningPortsGetter{
cacheDuration: 1 * time.Hour,
}

l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer l.Close()

ports, err := uut.GetListeningPorts()
require.NoError(t, err)
found := false
for _, port := range ports {
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) {
found = true
break
}
}
require.True(t, found)

// check that we cache the ports
err = l.Close()
require.NoError(t, err)
portsNew, err := uut.GetListeningPorts()
require.NoError(t, err)
require.Equal(t, ports, portsNew)

// note that it's unsafe to try to assert that a port does not exist in the response
// because the OS may reallocate the port very quickly.
}
12 changes: 10 additions & 2 deletions agent/ports_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@

package agent

import "github.com/coder/coder/v2/codersdk"
import (
"time"

func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
"github.com/coder/coder/v2/codersdk"
)

type osListeningPortsGetter struct {
cacheDuration time.Duration
}

func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
// Can't scan for ports on non-linux or non-windows_amd64 systems at the
// moment. The UI will not show any "no ports found" message to the user, so
// the user won't suspect a thing.
Expand Down
Loading
Loading