diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index e45ca6a49e29a..5fb15843f1bf1 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -1,11 +1,14 @@ package toolsdk import ( + "bytes" "context" "errors" "fmt" "io" "strings" + "sync" + "time" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -20,6 +23,7 @@ import ( type WorkspaceBashArgs struct { Workspace string `json:"workspace"` Command string `json:"command"` + TimeoutMs int `json:"timeout_ms,omitempty"` } type WorkspaceBashResult struct { @@ -43,9 +47,12 @@ The workspace parameter supports various formats: - workspace.agent (specific agent) - owner/workspace.agent +The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms). +If the command times out, all output captured up to that point is returned with a cancellation message. + Examples: - workspace: "my-workspace", command: "ls -la" -- workspace: "john/dev-env", command: "git status" +- workspace: "john/dev-env", command: "git status", timeout_ms: 30000 - workspace: "my-workspace.main", command: "docker ps"`, Schema: aisdk.Schema{ Properties: map[string]any{ @@ -57,11 +64,17 @@ Examples: "type": "string", "description": "The bash command to execute in the workspace.", }, + "timeout_ms": map[string]any{ + "type": "integer", + "description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.", + "default": 60000, + "minimum": 1, + }, }, Required: []string{"workspace", "command"}, }, }, - Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) { + Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) { if args.Workspace == "" { return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty") } @@ -69,6 +82,9 @@ Examples: return WorkspaceBashResult{}, xerrors.New("command cannot be empty") } + ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) + defer cancel() + // Normalize workspace input to handle various formats workspaceName := NormalizeWorkspaceInput(args.Workspace) @@ -119,23 +135,42 @@ Examples: } defer session.Close() - // Execute command and capture output - output, err := session.CombinedOutput(args.Command) + // Set default timeout if not specified (60 seconds) + timeoutMs := args.TimeoutMs + if timeoutMs <= 0 { + timeoutMs = 60000 + } + + // Create context with timeout + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond) + defer cancel() + + // Execute command with timeout handling + output, err := executeCommandWithTimeout(ctx, session, args.Command) outputStr := strings.TrimSpace(string(output)) + // Handle command execution results if err != nil { - // Check if it's an SSH exit error to get the exit code - var exitErr *gossh.ExitError - if errors.As(err, &exitErr) { + // Check if the command timed out + if errors.Is(context.Cause(ctx), context.DeadlineExceeded) { + outputStr += "\nCommand canceled due to timeout" return WorkspaceBashResult{ Output: outputStr, - ExitCode: exitErr.ExitStatus(), + ExitCode: 124, }, nil } - // For other errors, return exit code 1 + + // Extract exit code from SSH error if available + exitCode := 1 + var exitErr *gossh.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitStatus() + } + + // For other errors, use standard timeout or generic error code return WorkspaceBashResult{ Output: outputStr, - ExitCode: 1, + ExitCode: exitCode, }, nil } @@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string { return normalized } + +// executeCommandWithTimeout executes a command with timeout support +func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) { + // Set up pipes to capture output + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return nil, xerrors.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := session.StderrPipe() + if err != nil { + return nil, xerrors.Errorf("failed to create stderr pipe: %w", err) + } + + // Start the command + if err := session.Start(command); err != nil { + return nil, xerrors.Errorf("failed to start command: %w", err) + } + + // Create a thread-safe buffer for combined output + var output bytes.Buffer + var mu sync.Mutex + safeWriter := &syncWriter{w: &output, mu: &mu} + + // Use io.MultiWriter to combine stdout and stderr + multiWriter := io.MultiWriter(safeWriter) + + // Channel to signal when command completes + done := make(chan error, 1) + + // Start goroutine to copy output and wait for completion + go func() { + // Copy stdout and stderr concurrently + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = io.Copy(multiWriter, stdoutPipe) + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(multiWriter, stderrPipe) + }() + + // Wait for all output to be copied + wg.Wait() + + // Wait for the command to complete + done <- session.Wait() + }() + + // Wait for either completion or context cancellation + select { + case err := <-done: + // Command completed normally + return safeWriter.Bytes(), err + case <-ctx.Done(): + // Context was canceled (timeout or other cancellation) + // Close the session to stop the command + _ = session.Close() + + // Give a brief moment to collect any remaining output + timer := time.NewTimer(50 * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + // Timer expired, return what we have + case err := <-done: + // Command finished during grace period + return safeWriter.Bytes(), err + } + + return safeWriter.Bytes(), context.Cause(ctx) + } +} + +// syncWriter is a thread-safe writer +type syncWriter struct { + w *bytes.Buffer + mu *sync.Mutex +} + +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + +func (sw *syncWriter) Bytes() []byte { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Bytes() +} diff --git a/codersdk/toolsdk/bash_test.go b/codersdk/toolsdk/bash_test.go index 474071fc45acb..53ac480039278 100644 --- a/codersdk/toolsdk/bash_test.go +++ b/codersdk/toolsdk/bash_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk/toolsdk" ) @@ -40,7 +42,7 @@ func TestWorkspaceBash(t *testing.T) { t.Run("ErrorScenarios", func(t *testing.T) { t.Parallel() - deps := toolsdk.Deps{} // Empty deps will cause client access to fail + deps := toolsdk.Deps{} ctx := context.Background() // Test input validation errors (these should fail before client access) @@ -159,3 +161,180 @@ func TestAllToolsIncludesBash(t *testing.T) { } require.True(t, found, "WorkspaceBash tool should be included in toolsdk.All") } + +// Note: Unit testing ExecuteCommandWithTimeout is challenging because it expects +// a concrete SSH session type. The integration tests above demonstrate the +// timeout functionality with a real SSH connection and mock clock. + +func TestWorkspaceBashTimeout(t *testing.T) { + t.Parallel() + + t.Run("TimeoutDefaultValue", func(t *testing.T) { + t.Parallel() + + // Test that the TimeoutMs field can be set and read correctly + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: "echo test", + TimeoutMs: 0, // Should default to 60000 in handler + } + + // Verify that the TimeoutMs field exists and can be set + require.Equal(t, 0, args.TimeoutMs) + + // Test setting a positive value + args.TimeoutMs = 5000 + require.Equal(t, 5000, args.TimeoutMs) + }) + + t.Run("TimeoutNegativeValue", func(t *testing.T) { + t.Parallel() + + // Test that negative values can be set and will be handled by the default logic + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: "echo test", + TimeoutMs: -100, + } + + require.Equal(t, -100, args.TimeoutMs) + + // The actual defaulting to 60000 happens inside the handler + // We can't test it without a full integration test setup + }) + + t.Run("TimeoutSchemaValidation", func(t *testing.T) { + t.Parallel() + + tool := toolsdk.WorkspaceBash + + // Check that timeout_ms is in the schema + require.Contains(t, tool.Schema.Properties, "timeout_ms") + + timeoutProperty := tool.Schema.Properties["timeout_ms"].(map[string]any) + require.Equal(t, "integer", timeoutProperty["type"]) + require.Equal(t, 60000, timeoutProperty["default"]) + require.Equal(t, 1, timeoutProperty["minimum"]) + require.Contains(t, timeoutProperty["description"], "timeout in milliseconds") + }) + + t.Run("TimeoutDescriptionUpdated", func(t *testing.T) { + t.Parallel() + + tool := toolsdk.WorkspaceBash + + // Check that description mentions timeout functionality + require.Contains(t, tool.Description, "timeout_ms parameter") + require.Contains(t, tool.Description, "defaults to 60000ms") + require.Contains(t, tool.Description, "timeout_ms: 30000") + }) + + t.Run("TimeoutCommandScenario", func(t *testing.T) { + t.Parallel() + + // Scenario: echo "123"; sleep 60; echo "456" with 5ms timeout + // In this scenario, we'd expect to see "123" in the output and a cancellation message + args := toolsdk.WorkspaceBashArgs{ + Workspace: "test-workspace", + Command: `echo "123"; sleep 60; echo "456"`, // This command would take 60+ seconds + TimeoutMs: 5, // 5ms timeout - should timeout after first echo + } + + // Verify the args are structured correctly for the intended test scenario + require.Equal(t, "test-workspace", args.Workspace) + require.Contains(t, args.Command, `echo "123"`) + require.Contains(t, args.Command, "sleep 60") + require.Contains(t, args.Command, `echo "456"`) + require.Equal(t, 5, args.TimeoutMs) + + // Note: The actual timeout behavior would need to be tested with a real workspace + // This test just verifies the structure is correct for the timeout scenario + }) +} + +func TestWorkspaceBashTimeoutIntegration(t *testing.T) { + t.Parallel() + + t.Run("ActualTimeoutBehavior", func(t *testing.T) { + t.Parallel() + + // Scenario: echo "123"; sleep 60; echo "456" with 5s timeout + // In this scenario, we'd expect to see "123" in the output and a cancellation message + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready like other SSH tests do + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + // Use real clock for integration test + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "123" && sleep 60 && echo "456"`, // This command would take 60+ seconds + TimeoutMs: 2000, // 2 seconds timeout - should timeout after first echo + } + + result, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, args) + + // Should not error (timeout is handled gracefully) + require.NoError(t, err) + + t.Logf("Test results: exitCode=%d, output=%q, error=%v", result.ExitCode, result.Output, err) + + // Should have a non-zero exit code (timeout or error) + require.NotEqual(t, 0, result.ExitCode, "Expected non-zero exit code for timeout") + + t.Logf("result.Output: %s", result.Output) + + // Should contain the first echo output + require.Contains(t, result.Output, "123") + + // Should NOT contain the second echo (it never executed due to timeout) + require.NotContains(t, result.Output, "456", "Should not contain output after sleep") + }) + + t.Run("NormalCommandExecution", func(t *testing.T) { + t.Parallel() + + // Test that normal commands still work with timeout functionality present + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + // Start the agent and wait for it to be fully ready + _ = agenttest.New(t, client.URL, agentToken) + + // Wait for workspace agents to be ready + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + deps, err := toolsdk.NewDeps(client) + require.NoError(t, err) + ctx := context.Background() + + args := toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: `echo "normal command"`, // Quick command that should complete normally + TimeoutMs: 5000, // 5 second timeout - plenty of time + } + + result, err := toolsdk.WorkspaceBash.Handler(ctx, deps, args) + + // Should not error + require.NoError(t, err) + + t.Logf("result.Output: %s", result.Output) + + // Should have exit code 0 (success) + require.Equal(t, 0, result.ExitCode) + + // Should contain the expected output + require.Equal(t, "normal command", result.Output) + + // Should NOT contain timeout message + require.NotContains(t, result.Output, "Command canceled due to timeout") + }) +}