Skip to content

Commit b666d52

Browse files
hugodutka and ThomasK33 authored
feat(codersdk/toolsdk): add MCP workspace bash background parameter (#19034)
Addresses coder/internal#820

---------

Signed-off-by: Thomas Kosiewski <tk@coder.com>
Co-authored-by: Thomas Kosiewski <tk@coder.com>
1 parent bf78966 commit b666d52

File tree

3 files changed

+202
-35
lines changed

3 files changed

+202
-35
lines changed

codersdk/toolsdk/bash.go

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import (
2121
)
2222

2323
type WorkspaceBashArgs struct {
24-
Workspace string `json:"workspace"`
25-
Command string `json:"command"`
26-
TimeoutMs int `json:"timeout_ms,omitempty"`
24+
Workspace string `json:"workspace"`
25+
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
27+
Background bool `json:"background,omitempty"`
2728
}
2829

2930
type WorkspaceBashResult struct {
@@ -50,9 +51,13 @@ The workspace parameter supports various formats:
5051
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
5152
If the command times out, all output captured up to that point is returned with a cancellation message.
5253
54+
For background commands (background: true), output is captured until the timeout is reached, then the command
55+
continues running in the background. The captured output is returned as the result.
56+
5357
Examples:
5458
- workspace: "my-workspace", command: "ls -la"
5559
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
60+
- workspace: "my-workspace", command: "npm run dev", background: true, timeout_ms: 10000
5661
- workspace: "my-workspace.main", command: "docker ps"`,
5762
Schema: aisdk.Schema{
5863
Properties: map[string]any{
@@ -70,6 +75,10 @@ Examples:
7075
"default": 60000,
7176
"minimum": 1,
7277
},
78+
"background": map[string]any{
79+
"type": "boolean",
80+
"description": "Whether to run the command in the background. Output is captured until timeout, then the command continues running in the background.",
81+
},
7382
},
7483
Required: []string{"workspace", "command"},
7584
},
@@ -137,23 +146,35 @@ Examples:
137146

138147
// Set default timeout if not specified (60 seconds)
139148
timeoutMs := args.TimeoutMs
149+
defaultTimeoutMs := 60000
140150
if timeoutMs <= 0 {
141-
timeoutMs = 60000
151+
timeoutMs = defaultTimeoutMs
152+
}
153+
command := args.Command
154+
if args.Background {
155+
// For background commands, use nohup directly to ensure they survive SSH session
156+
// termination. This captures output normally but allows the process to continue
157+
// running even after the SSH connection closes.
158+
command = fmt.Sprintf("nohup %s </dev/null 2>&1", args.Command)
142159
}
143160

144-
// Create context with timeout
145-
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146-
defer cancel()
161+
// Create context with command timeout (replace the broader MCP timeout)
162+
commandCtx, commandCancel := context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
163+
defer commandCancel()
147164

148165
// Execute command with timeout handling
149-
output, err := executeCommandWithTimeout(ctx, session, args.Command)
166+
output, err := executeCommandWithTimeout(commandCtx, session, command)
150167
outputStr := strings.TrimSpace(string(output))
151168

152169
// Handle command execution results
153170
if err != nil {
154171
// Check if the command timed out
155-
if errors.Is(context.Cause(ctx), context.DeadlineExceeded) {
156-
outputStr += "\nCommand canceled due to timeout"
172+
if errors.Is(context.Cause(commandCtx), context.DeadlineExceeded) {
173+
if args.Background {
174+
outputStr += "\nCommand continues running in background"
175+
} else {
176+
outputStr += "\nCommand canceled due to timeout"
177+
}
157178
return WorkspaceBashResult{
158179
Output: outputStr,
159180
ExitCode: 124,
@@ -387,21 +408,27 @@ func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, comm
387408
return safeWriter.Bytes(), err
388409
case <-ctx.Done():
389410
// Context was canceled (timeout or other cancellation)
390-
// Close the session to stop the command
391-
_ = session.Close()
411+
// Close the session to stop the command, but handle errors gracefully
412+
closeErr := session.Close()
392413

393-
// Give a brief moment to collect any remaining output
394-
timer := time.NewTimer(50 * time.Millisecond)
414+
// Give a brief moment to collect any remaining output and for goroutines to finish
415+
timer := time.NewTimer(100 * time.Millisecond)
395416
defer timer.Stop()
396417

397418
select {
398419
case <-timer.C:
399420
// Timer expired, return what we have
421+
break
400422
case err := <-done:
401423
// Command finished during grace period
402-
return safeWriter.Bytes(), err
424+
if closeErr == nil {
425+
return safeWriter.Bytes(), err
426+
}
427+
// If session close failed, prioritize the context error
428+
break
403429
}
404430

431+
// Return the collected output with the context error
405432
return safeWriter.Bytes(), context.Cause(ctx)
406433
}
407434
}
@@ -421,5 +448,9 @@ func (sw *syncWriter) Write(p []byte) (n int, err error) {
421448
func (sw *syncWriter) Bytes() []byte {
422449
sw.mu.Lock()
423450
defer sw.mu.Unlock()
424-
return sw.w.Bytes()
451+
// Return a copy to prevent race conditions with the underlying buffer
452+
b := sw.w.Bytes()
453+
result := make([]byte, len(b))
454+
copy(result, b)
455+
return result
425456
}

codersdk/toolsdk/bash_test.go

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/coder/coder/v2/agent/agenttest"
1010
"github.com/coder/coder/v2/coderd/coderdtest"
1111
"github.com/coder/coder/v2/codersdk/toolsdk"
12+
"github.com/coder/coder/v2/testutil"
1213
)
1314

1415
func TestWorkspaceBash(t *testing.T) {
@@ -174,8 +175,6 @@ func TestWorkspaceBashTimeout(t *testing.T) {
174175

175176
// Test that the TimeoutMs field can be set and read correctly
176177
args := toolsdk.WorkspaceBashArgs{
177-
Workspace: "test-workspace",
178-
Command: "echo test",
179178
TimeoutMs: 0, // Should default to 60000 in handler
180179
}
181180

@@ -192,8 +191,6 @@ func TestWorkspaceBashTimeout(t *testing.T) {
192191

193192
// Test that negative values can be set and will be handled by the default logic
194193
args := toolsdk.WorkspaceBashArgs{
195-
Workspace: "test-workspace",
196-
Command: "echo test",
197194
TimeoutMs: -100,
198195
}
199196

@@ -279,7 +276,7 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
279276
TimeoutMs: 2000, // 2 seconds timeout - should timeout after first echo
280277
}
281278

282-
result, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, args)
279+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
283280

284281
// Should not error (timeout is handled gracefully)
285282
require.NoError(t, err)
@@ -313,15 +310,15 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
313310

314311
deps, err := toolsdk.NewDeps(client)
315312
require.NoError(t, err)
316-
ctx := context.Background()
317313

318314
args := toolsdk.WorkspaceBashArgs{
319315
Workspace: workspace.Name,
320316
Command: `echo "normal command"`, // Quick command that should complete normally
321317
TimeoutMs: 5000, // 5 second timeout - plenty of time
322318
}
323319

324-
result, err := toolsdk.WorkspaceBash.Handler(ctx, deps, args)
320+
// Use testTool to register the tool as tested and satisfy coverage validation
321+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
325322

326323
// Should not error
327324
require.NoError(t, err)
@@ -338,3 +335,142 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
338335
require.NotContains(t, result.Output, "Command canceled due to timeout")
339336
})
340337
}
338+
339+
func TestWorkspaceBashBackgroundIntegration(t *testing.T) {
340+
t.Parallel()
341+
342+
t.Run("BackgroundCommandCapturesOutput", func(t *testing.T) {
343+
t.Parallel()
344+
345+
client, workspace, agentToken := setupWorkspaceForAgent(t)
346+
347+
// Start the agent and wait for it to be fully ready
348+
_ = agenttest.New(t, client.URL, agentToken)
349+
350+
// Wait for workspace agents to be ready
351+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
352+
353+
deps, err := toolsdk.NewDeps(client)
354+
require.NoError(t, err)
355+
356+
args := toolsdk.WorkspaceBashArgs{
357+
Workspace: workspace.Name,
358+
Command: `echo "started" && sleep 60 && echo "completed"`, // Command that would take 60+ seconds
359+
Background: true, // Run in background
360+
TimeoutMs: 2000, // 2 second timeout
361+
}
362+
363+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
364+
365+
// Should not error
366+
require.NoError(t, err)
367+
368+
t.Logf("Background result: exitCode=%d, output=%q", result.ExitCode, result.Output)
369+
370+
// Should have exit code 124 (timeout) since command times out
371+
require.Equal(t, 124, result.ExitCode)
372+
373+
// Should capture output up to timeout point
374+
require.Contains(t, result.Output, "started", "Should contain output captured before timeout")
375+
376+
// Should NOT contain the second echo (it had not been reached when the timeout fired)
377+
require.NotContains(t, result.Output, "completed", "Should not contain output after timeout")
378+
379+
// Should contain background continuation message
380+
require.Contains(t, result.Output, "Command continues running in background")
381+
})
382+
383+
t.Run("BackgroundVsNormalExecution", func(t *testing.T) {
384+
t.Parallel()
385+
386+
client, workspace, agentToken := setupWorkspaceForAgent(t)
387+
388+
// Start the agent and wait for it to be fully ready
389+
_ = agenttest.New(t, client.URL, agentToken)
390+
391+
// Wait for workspace agents to be ready
392+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
393+
394+
deps, err := toolsdk.NewDeps(client)
395+
require.NoError(t, err)
396+
397+
// First run the same command in normal mode
398+
normalArgs := toolsdk.WorkspaceBashArgs{
399+
Workspace: workspace.Name,
400+
Command: `echo "hello world"`,
401+
Background: false,
402+
}
403+
404+
normalResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, normalArgs)
405+
require.NoError(t, err)
406+
407+
// Normal mode should return the actual output
408+
require.Equal(t, 0, normalResult.ExitCode)
409+
require.Equal(t, "hello world", normalResult.Output)
410+
411+
// Now run the same command in background mode
412+
backgroundArgs := toolsdk.WorkspaceBashArgs{
413+
Workspace: workspace.Name,
414+
Command: `echo "hello world"`,
415+
Background: true,
416+
}
417+
418+
backgroundResult, err := testTool(t, toolsdk.WorkspaceBash, deps, backgroundArgs)
419+
require.NoError(t, err)
420+
421+
t.Logf("Normal result: %q", normalResult.Output)
422+
t.Logf("Background result: %q", backgroundResult.Output)
423+
424+
// Background mode should also return the actual output since command completes quickly
425+
require.Equal(t, 0, backgroundResult.ExitCode)
426+
require.Equal(t, "hello world", backgroundResult.Output)
427+
})
428+
429+
t.Run("BackgroundCommandContinuesAfterTimeout", func(t *testing.T) {
430+
t.Parallel()
431+
432+
client, workspace, agentToken := setupWorkspaceForAgent(t)
433+
434+
// Start the agent and wait for it to be fully ready
435+
_ = agenttest.New(t, client.URL, agentToken)
436+
437+
// Wait for workspace agents to be ready
438+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
439+
440+
deps, err := toolsdk.NewDeps(client)
441+
require.NoError(t, err)
442+
443+
args := toolsdk.WorkspaceBashArgs{
444+
Workspace: workspace.Name,
445+
Command: `echo "started" && sleep 4 && echo "done" > /tmp/bg-test-done`, // Command that will timeout but continue
446+
TimeoutMs: 2000, // 2000ms timeout (shorter than command duration)
447+
Background: true, // Run in background
448+
}
449+
450+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
451+
452+
// Should not error but should timeout
453+
require.NoError(t, err)
454+
455+
t.Logf("Background with timeout result: exitCode=%d, output=%q", result.ExitCode, result.Output)
456+
457+
// Should have timeout exit code
458+
require.Equal(t, 124, result.ExitCode)
459+
460+
// Should capture output before timeout
461+
require.Contains(t, result.Output, "started", "Should contain output captured before timeout")
462+
463+
// Should contain background continuation message
464+
require.Contains(t, result.Output, "Command continues running in background")
465+
466+
// Wait for the background command to complete (even though SSH session timed out)
467+
require.Eventually(t, func() bool {
468+
checkArgs := toolsdk.WorkspaceBashArgs{
469+
Workspace: workspace.Name,
470+
Command: `cat /tmp/bg-test-done 2>/dev/null || echo "not found"`,
471+
}
472+
checkResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, checkArgs)
473+
return err == nil && checkResult.Output == "done"
474+
}, testutil.WaitMedium, testutil.IntervalMedium, "Background command should continue running and complete after timeout")
475+
})
476+
}

codersdk/toolsdk/toolsdk_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ var testedTools sync.Map
456456
// This is to mimic how we expect external callers to use the tool.
457457
func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) {
458458
t.Helper()
459-
defer func() { testedTools.Store(tool.Tool.Name, true) }()
459+
defer func() { testedTools.Store(tool.Name, true) }()
460460
toolArgs, err := json.Marshal(args)
461461
require.NoError(t, err, "failed to marshal args")
462462
result, err := tool.Generic().Handler(t.Context(), tb, toolArgs)
@@ -625,23 +625,23 @@ func TestToolSchemaFields(t *testing.T) {
625625

626626
// Test that all tools have the required Schema fields (Properties and Required)
627627
for _, tool := range toolsdk.All {
628-
t.Run(tool.Tool.Name, func(t *testing.T) {
628+
t.Run(tool.Name, func(t *testing.T) {
629629
t.Parallel()
630630

631631
// Check that Properties is not nil
632-
require.NotNil(t, tool.Tool.Schema.Properties,
633-
"Tool %q missing Schema.Properties", tool.Tool.Name)
632+
require.NotNil(t, tool.Schema.Properties,
633+
"Tool %q missing Schema.Properties", tool.Name)
634634

635635
// Check that Required is not nil
636-
require.NotNil(t, tool.Tool.Schema.Required,
637-
"Tool %q missing Schema.Required", tool.Tool.Name)
636+
require.NotNil(t, tool.Schema.Required,
637+
"Tool %q missing Schema.Required", tool.Name)
638638

639639
// Ensure Properties has entries for all required fields
640-
for _, requiredField := range tool.Tool.Schema.Required {
641-
_, exists := tool.Tool.Schema.Properties[requiredField]
640+
for _, requiredField := range tool.Schema.Required {
641+
_, exists := tool.Schema.Properties[requiredField]
642642
require.True(t, exists,
643643
"Tool %q requires field %q but it is not defined in Properties",
644-
tool.Tool.Name, requiredField)
644+
tool.Name, requiredField)
645645
}
646646
})
647647
}
@@ -652,16 +652,16 @@ func TestToolSchemaFields(t *testing.T) {
652652
func TestMain(m *testing.M) {
653653
// Initialize testedTools
654654
for _, tool := range toolsdk.All {
655-
testedTools.Store(tool.Tool.Name, false)
655+
testedTools.Store(tool.Name, false)
656656
}
657657

658658
code := m.Run()
659659

660660
// Ensure all tools have been tested
661661
var untested []string
662662
for _, tool := range toolsdk.All {
663-
if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) {
664-
untested = append(untested, tool.Tool.Name)
663+
if tested, ok := testedTools.Load(tool.Name); !ok || !tested.(bool) {
664+
untested = append(untested, tool.Name)
665665
}
666666
}
667667

0 commit comments

Comments
 (0)