Skip to content
18 changes: 12 additions & 6 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type Options struct {
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
Clock quartz.Clock
SocketServerEnabled bool
SocketPath string // Path for the agent socket server socket
}

Expand Down Expand Up @@ -193,6 +194,7 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
socketPath: options.SocketPath,
socketServerEnabled: options.SocketServerEnabled,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
Expand Down Expand Up @@ -275,8 +277,9 @@ type agent struct {
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API

socketPath string
socketServer *agentsocket.Server
socketServerEnabled bool
socketPath string
socketServer *agentsocket.Server
}

func (a *agent) TailnetConn() *tailnet.Conn {
Expand Down Expand Up @@ -364,14 +367,17 @@ func (a *agent) init() {

// initSocketServer initializes server that allows direct communication with a workspace agent using IPC.
func (a *agent) initSocketServer() {
if a.socketPath == "" {
a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
if !a.socketServerEnabled {
a.logger.Info(a.hardCtx, "socket server is disabled")
return
}

server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
server, err := agentsocket.NewServer(
a.logger.Named("socket"),
agentsocket.WithPath(a.socketPath),
)
if err != nil {
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
return
}

Expand Down
146 changes: 146 additions & 0 deletions agent/agentsocket/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package agentsocket

import (
"context"

"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcconn"

"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
)

// Option represents a configuration option for NewClient.
type Option func(*options)

type options struct {
path string
}

// WithPath sets the socket path. If not provided or empty, the client will
// auto-discover the default socket path.
func WithPath(path string) Option {
return func(opts *options) {
if path == "" {
return
}
opts.path = path
}
}

// Client provides a client for communicating with the workspace agentsocket API.
type Client struct {
client proto.DRPCAgentSocketClient
conn drpc.Conn
}

// NewClient creates a new socket client and opens a connection to the socket.
// If path is not provided via WithPath or is empty, it will auto-discover the
// default socket path.
func NewClient(ctx context.Context, opts ...Option) (*Client, error) {
options := &options{}
for _, opt := range opts {
opt(options)
}

conn, err := dialSocket(ctx, options.path)
if err != nil {
return nil, xerrors.Errorf("connect to socket: %w", err)
}

drpcConn := drpcconn.New(conn)
client := proto.NewDRPCAgentSocketClient(drpcConn)

return &Client{
client: client,
conn: drpcConn,
}, nil
}

// Close closes the socket connection.
func (c *Client) Close() error {
return c.conn.Close()
}

// Ping sends a ping request to the agent.
func (c *Client) Ping(ctx context.Context) error {
_, err := c.client.Ping(ctx, &proto.PingRequest{})
return err
}

// SyncStart starts a unit in the dependency graph.
func (c *Client) SyncStart(ctx context.Context, unitName unit.ID) error {
_, err := c.client.SyncStart(ctx, &proto.SyncStartRequest{
Unit: string(unitName),
})
return err
}

// SyncWant declares a dependency between units.
func (c *Client) SyncWant(ctx context.Context, unitName, dependsOn unit.ID) error {
_, err := c.client.SyncWant(ctx, &proto.SyncWantRequest{
Unit: string(unitName),
DependsOn: string(dependsOn),
})
return err
}

// SyncComplete marks a unit as complete in the dependency graph.
func (c *Client) SyncComplete(ctx context.Context, unitName unit.ID) error {
_, err := c.client.SyncComplete(ctx, &proto.SyncCompleteRequest{
Unit: string(unitName),
})
return err
}

// SyncReady requests whether a unit is ready to be started. That is, all dependencies are satisfied.
func (c *Client) SyncReady(ctx context.Context, unitName unit.ID) (bool, error) {
resp, err := c.client.SyncReady(ctx, &proto.SyncReadyRequest{
Unit: string(unitName),
})
return resp.Ready, err
}

// SyncStatus gets the status of a unit and its dependencies.
func (c *Client) SyncStatus(ctx context.Context, unitName unit.ID) (SyncStatusResponse, error) {
resp, err := c.client.SyncStatus(ctx, &proto.SyncStatusRequest{
Unit: string(unitName),
})
if err != nil {
return SyncStatusResponse{}, err
}

var dependencies []DependencyInfo
for _, dep := range resp.Dependencies {
dependencies = append(dependencies, DependencyInfo{
DependsOn: unit.ID(dep.DependsOn),
RequiredStatus: unit.Status(dep.RequiredStatus),
CurrentStatus: unit.Status(dep.CurrentStatus),
IsSatisfied: dep.IsSatisfied,
})
}

return SyncStatusResponse{
UnitName: unitName,
Status: unit.Status(resp.Status),
IsReady: resp.IsReady,
Dependencies: dependencies,
}, nil
}

// SyncStatusResponse contains the status information for a unit.
type SyncStatusResponse struct {
UnitName unit.ID `table:"unit,default_sort" json:"unit_name"`
Status unit.Status `table:"status" json:"status"`
IsReady bool `table:"ready" json:"is_ready"`
Dependencies []DependencyInfo `table:"dependencies" json:"dependencies"`
}

// DependencyInfo contains information about a unit dependency.
type DependencyInfo struct {
DependsOn unit.ID `table:"depends on,default_sort" json:"depends_on"`
RequiredStatus unit.Status `table:"required status" json:"required_status"`
CurrentStatus unit.Status `table:"current status" json:"current_status"`
IsSatisfied bool `table:"satisfied" json:"is_satisfied"`
}
35 changes: 9 additions & 26 deletions agent/agentsocket/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sync"

"golang.org/x/xerrors"

"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"

Expand All @@ -17,21 +16,6 @@ import (
"github.com/coder/coder/v2/codersdk/drpcsdk"
)

// Client wraps a DRPC client with connection management.
// This type is defined here so it's available on all platforms.
type Client struct {
proto.DRPCAgentSocketClient
conn net.Conn
}

// Close closes the client connection.
func (c *Client) Close() error {
if c.conn != nil {
return c.conn.Close()
}
return nil
}

// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
// Do not invoke Server{} directly. Use NewServer() instead.
type Server struct {
Expand All @@ -47,11 +31,17 @@ type Server struct {
wg sync.WaitGroup
}

func NewServer(path string, logger slog.Logger) (*Server, error) {
// NewServer creates a new agent socket server.
func NewServer(logger slog.Logger, opts ...Option) (*Server, error) {
options := &options{}
for _, opt := range opts {
opt(options)
}

logger = logger.Named("agentsocket-server")
server := &Server{
logger: logger,
path: path,
path: options.path,
service: &DRPCAgentSocketService{
logger: logger,
unitManager: unit.NewManager(),
Expand All @@ -75,14 +65,6 @@ func NewServer(path string, logger slog.Logger) (*Server, error) {
},
})

if server.path == "" {
var err error
server.path, err = getDefaultSocketPath()
if err != nil {
return nil, xerrors.Errorf("get default socket path: %w", err)
}
}

listener, err := createSocket(server.path)
if err != nil {
return nil, xerrors.Errorf("create socket: %w", err)
Expand All @@ -105,6 +87,7 @@ func NewServer(path string, logger slog.Logger) (*Server, error) {
return server, nil
}

// Close stops the server and cleans up resources.
func (s *Server) Close() error {
s.mu.Lock()

Expand Down
Loading
Loading