Skip to content

Commit b666d52

Browse files
hugodutka and ThomasK33 authored
feat(codersdk/toolsdk): add MCP workspace bash background parameter (#19034)
Addresses coder/internal#820

---------

Signed-off-by: Thomas Kosiewski <tk@coder.com>
Co-authored-by: Thomas Kosiewski <tk@coder.com>
1 parent bf78966 commit b666d52

File tree

3 files changed

+202
-35
lines changed

3 files changed

+202
-35
lines changed

codersdk/toolsdk/bash.go

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import (
2121
)
2222

2323
type WorkspaceBashArgs struct {
24-
Workspace string `json:"workspace"`
25-
Command string `json:"command"`
26-
TimeoutMs int `json:"timeout_ms,omitempty"`
24+
Workspace string `json:"workspace"`
25+
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
27+
Background bool `json:"background,omitempty"`
2728
}
2829

2930
type WorkspaceBashResult struct {
@@ -50,9 +51,13 @@ The workspace parameter supports various formats:
5051
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
5152
If the command times out, all output captured up to that point is returned with a cancellation message.
5253
54+
For background commands (background: true), output is captured until the timeout is reached, then the command
55+
continues running in the background. The captured output is returned as the result.
56+
5357
Examples:
5458
- workspace: "my-workspace", command: "ls -la"
5559
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
60+
- workspace: "my-workspace", command: "npm run dev", background: true, timeout_ms: 10000
5661
- workspace: "my-workspace.main", command: "docker ps"`,
5762
Schema: aisdk.Schema{
5863
Properties: map[string]any{
@@ -70,6 +75,10 @@ Examples:
7075
"default": 60000,
7176
"minimum": 1,
7277
},
78+
"background": map[string]any{
79+
"type": "boolean",
80+
"description": "Whether to run the command in the background. Output is captured until timeout, then the command continues running in the background.",
81+
},
7382
},
7483
Required: []string{"workspace", "command"},
7584
},
@@ -137,23 +146,35 @@ Examples:
137146

138147
// Set default timeout if not specified (60 seconds)
139148
timeoutMs := args.TimeoutMs
149+
defaultTimeoutMs := 60000
140150
if timeoutMs <= 0 {
141-
timeoutMs = 60000
151+
timeoutMs = defaultTimeoutMs
152+
}
153+
command := args.Command
154+
if args.Background {
155+
// For background commands, use nohup directly to ensure they survive SSH session
156+
// termination. This captures output normally but allows the process to continue
157+
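// running even after the SSH connection closes. NOTE(review): for a compound command such as `a && b`, nohup wraps only the first simple command and the redirections bind only to the last one — confirm this is the intended scope.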
158+
command = fmt.Sprintf("nohup %s </dev/null 2>&1", args.Command)
142159
}
143160

144-
// Create context with timeout
145-
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146-
defer cancel()
161+
// Create a context with the command timeout. It is derived from ctx, so whichever deadline is earlier (this one or the parent's) applies.
162+
commandCtx, commandCancel := context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
163+
defer commandCancel()
147164

148165
// Execute command with timeout handling
149-
output, err := executeCommandWithTimeout(ctx, session, args.Command)
166+
output, err := executeCommandWithTimeout(commandCtx, session, command)
150167
outputStr := strings.TrimSpace(string(output))
151168

152169
// Handle command execution results
153170
if err != nil {
154171
// Check if the command timed out
155-
if errors.Is(context.Cause(ctx), context.DeadlineExceeded) {
156-
outputStr += "\nCommand canceled due to timeout"
172+
if errors.Is(context.Cause(commandCtx), context.DeadlineExceeded) {
173+
if args.Background {
174+
outputStr += "\nCommand continues running in background"
175+
} else {
176+
outputStr += "\nCommand canceled due to timeout"
177+
}
157178
return WorkspaceBashResult{
158179
Output: outputStr,
159180
ExitCode: 124,
@@ -387,21 +408,27 @@ func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, comm
387408
return safeWriter.Bytes(), err
388409
case <-ctx.Done():
389410
// Context was canceled (timeout or other cancellation)
390-
// Close the session to stop the command
391-
_ = session.Close()
411+
// Close the session to stop the command, but handle errors gracefully
412+
closeErr := session.Close()
392413

393-
// Give a brief moment to collect any remaining output
394-
timer := time.NewTimer(50 * time.Millisecond)
414+
// Give a brief moment to collect any remaining output and for goroutines to finish
415+
timer := time.NewTimer(100 * time.Millisecond)
395416
defer timer.Stop()
396417

397418
select {
398419
case <-timer.C:
399420
// Timer expired, return what we have
421+
break
400422
case err := <-done:
401423
// Command finished during grace period
402-
return safeWriter.Bytes(), err
424+
if closeErr == nil {
425+
return safeWriter.Bytes(), err
426+
}
427+
// If session close failed, prioritize the context error
428+
break
403429
}
404430

431+
// Return the collected output with the context error
405432
return safeWriter.Bytes(), context.Cause(ctx)
406433
}
407434
}
@@ -421,5 +448,9 @@ func (sw *syncWriter) Write(p []byte) (n int, err error) {
421448
func (sw *syncWriter) Bytes() []byte {
422449
sw.mu.Lock()
423450
defer sw.mu.Unlock()
424-
return sw.w.Bytes()
451+
// Return a copy to prevent race conditions with the underlying buffer
452+
b := sw.w.Bytes()
453+
result := make([]byte, len(b))
454+
copy(result, b)
455+
return result
425456
}

codersdk/toolsdk/bash_test.go

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/coder/coder/v2/agent/agenttest"
1010
"github.com/coder/coder/v2/coderd/coderdtest"
1111
"github.com/coder/coder/v2/codersdk/toolsdk"
12+
"github.com/coder/coder/v2/testutil"
1213
)
1314

1415
func TestWorkspaceBash(t *testing.T) {
@@ -174,8 +175,6 @@ func TestWorkspaceBashTimeout(t *testing.T) {
174175

175176
// Test that the TimeoutMs field can be set and read correctly
176177
args := toolsdk.WorkspaceBashArgs{
177-
Workspace: "test-workspace",
178-
Command: "echo test",
179178
TimeoutMs: 0, // Should default to 60000 in handler
180179
}
181180

@@ -192,8 +191,6 @@ func TestWorkspaceBashTimeout(t *testing.T) {
192191

193192
// Test that negative values can be set and will be handled by the default logic
194193
args := toolsdk.WorkspaceBashArgs{
195-
Workspace: "test-workspace",
196-
Command: "echo test",
197194
TimeoutMs: -100,
198195
}
199196

@@ -279,7 +276,7 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
279276
TimeoutMs: 2000, // 2 seconds timeout - should timeout after first echo
280277
}
281278

282-
result, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, args)
279+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
283280

284281
// Should not error (timeout is handled gracefully)
285282
require.NoError(t, err)
@@ -313,15 +310,15 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
313310

314311
deps, err := toolsdk.NewDeps(client)
315312
require.NoError(t, err)
316-
ctx := context.Background()
317313

318314
args := toolsdk.WorkspaceBashArgs{
319315
Workspace: workspace.Name,
320316
Command: `echo "normal command"`, // Quick command that should complete normally
321317
TimeoutMs: 5000, // 5 second timeout - plenty of time
322318
}
323319

324-
result, err := toolsdk.WorkspaceBash.Handler(ctx, deps, args)
320+
// Use testTool to register the tool as tested and satisfy coverage validation
321+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
325322

326323
// Should not error
327324
require.NoError(t, err)
@@ -338,3 +335,142 @@ func TestWorkspaceBashTimeoutIntegration(t *testing.T) {
338335
require.NotContains(t, result.Output, "Command canceled due to timeout")
339336
})
340337
}
338+
339+
func TestWorkspaceBashBackgroundIntegration(t *testing.T) {
340+
t.Parallel()
341+
342+
t.Run("BackgroundCommandCapturesOutput", func(t *testing.T) {
343+
t.Parallel()
344+
345+
client, workspace, agentToken := setupWorkspaceForAgent(t)
346+
347+
// Start the agent and wait for it to be fully ready
348+
_ = agenttest.New(t, client.URL, agentToken)
349+
350+
// Wait for workspace agents to be ready
351+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
352+
353+
deps, err := toolsdk.NewDeps(client)
354+
require.NoError(t, err)
355+
356+
args := toolsdk.WorkspaceBashArgs{
357+
Workspace: workspace.Name,
358+
Command: `echo "started" && sleep 60 && echo "completed"`, // Command that would take 60+ seconds
359+
Background: true, // Run in background
360+
TimeoutMs: 2000, // 2 second timeout
361+
}
362+
363+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
364+
365+
// Should not error
366+
require.NoError(t, err)
367+
368+
t.Logf("Background result: exitCode=%d, output=%q", result.ExitCode, result.Output)
369+
370+
// Should have exit code 124 (timeout) since command times out
371+
require.Equal(t, 124, result.ExitCode)
372+
373+
// Should capture output up to timeout point
374+
require.Contains(t, result.Output, "started", "Should contain output captured before timeout")
375+
376+
// Should NOT contain the second echo (it would only run after the 60s sleep, long after the 2s capture window has closed)
377+
require.NotContains(t, result.Output, "completed", "Should not contain output after timeout")
378+
379+
// Should contain background continuation message
380+
require.Contains(t, result.Output, "Command continues running in background")
381+
})
382+
383+
t.Run("BackgroundVsNormalExecution", func(t *testing.T) {
384+
t.Parallel()
385+
386+
client, workspace, agentToken := setupWorkspaceForAgent(t)
387+
388+
// Start the agent and wait for it to be fully ready
389+
_ = agenttest.New(t, client.URL, agentToken)
390+
391+
// Wait for workspace agents to be ready
392+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
393+
394+
deps, err := toolsdk.NewDeps(client)
395+
require.NoError(t, err)
396+
397+
// First run the same command in normal mode
398+
normalArgs := toolsdk.WorkspaceBashArgs{
399+
Workspace: workspace.Name,
400+
Command: `echo "hello world"`,
401+
Background: false,
402+
}
403+
404+
normalResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, normalArgs)
405+
require.NoError(t, err)
406+
407+
// Normal mode should return the actual output
408+
require.Equal(t, 0, normalResult.ExitCode)
409+
require.Equal(t, "hello world", normalResult.Output)
410+
411+
// Now run the same command in background mode
412+
backgroundArgs := toolsdk.WorkspaceBashArgs{
413+
Workspace: workspace.Name,
414+
Command: `echo "hello world"`,
415+
Background: true,
416+
}
417+
418+
backgroundResult, err := testTool(t, toolsdk.WorkspaceBash, deps, backgroundArgs)
419+
require.NoError(t, err)
420+
421+
t.Logf("Normal result: %q", normalResult.Output)
422+
t.Logf("Background result: %q", backgroundResult.Output)
423+
424+
// Background mode should also return the actual output since command completes quickly
425+
require.Equal(t, 0, backgroundResult.ExitCode)
426+
require.Equal(t, "hello world", backgroundResult.Output)
427+
})
428+
429+
t.Run("BackgroundCommandContinuesAfterTimeout", func(t *testing.T) {
430+
t.Parallel()
431+
432+
client, workspace, agentToken := setupWorkspaceForAgent(t)
433+
434+
// Start the agent and wait for it to be fully ready
435+
_ = agenttest.New(t, client.URL, agentToken)
436+
437+
// Wait for workspace agents to be ready
438+
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
439+
440+
deps, err := toolsdk.NewDeps(client)
441+
require.NoError(t, err)
442+
443+
args := toolsdk.WorkspaceBashArgs{
444+
Workspace: workspace.Name,
445+
Command: `echo "started" && sleep 4 && echo "done" > /tmp/bg-test-done`, // Command that will timeout but continue
446+
TimeoutMs: 2000, // 2000ms timeout (shorter than command duration)
447+
Background: true, // Run in background
448+
}
449+
450+
result, err := testTool(t, toolsdk.WorkspaceBash, deps, args)
451+
452+
// Should not error but should timeout
453+
require.NoError(t, err)
454+
455+
t.Logf("Background with timeout result: exitCode=%d, output=%q", result.ExitCode, result.Output)
456+
457+
// Should have timeout exit code
458+
require.Equal(t, 124, result.ExitCode)
459+
460+
// Should capture output before timeout
461+
require.Contains(t, result.Output, "started", "Should contain output captured before timeout")
462+
463+
// Should contain background continuation message
464+
require.Contains(t, result.Output, "Command continues running in background")
465+
466+
// Wait for the background command to complete (even though SSH session timed out)
467+
require.Eventually(t, func() bool {
468+
checkArgs := toolsdk.WorkspaceBashArgs{
469+
Workspace: workspace.Name,
470+
Command: `cat /tmp/bg-test-done 2>/dev/null || echo "not found"`,
471+
}
472+
checkResult, err := toolsdk.WorkspaceBash.Handler(t.Context(), deps, checkArgs)
473+
return err == nil && checkResult.Output == "done"
474+
}, testutil.WaitMedium, testutil.IntervalMedium, "Background command should continue running and complete after timeout")
475+
})
476+
}

codersdk/toolsdk/toolsdk_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ var testedTools sync.Map
456456
// This is to mimic how we expect external callers to use the tool.
457457
func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) {
458458
t.Helper()
459-
defer func() { testedTools.Store(tool.Tool.Name, true) }()
459+
defer func() { testedTools.Store(tool.Name, true) }()
460460
toolArgs, err := json.Marshal(args)
461461
require.NoError(t, err, "failed to marshal args")
462462
result, err := tool.Generic().Handler(t.Context(), tb, toolArgs)
@@ -625,23 +625,23 @@ func TestToolSchemaFields(t *testing.T) {
625625

626626
// Test that all tools have the required Schema fields (Properties and Required)
627627
for _, tool := range toolsdk.All {
628-
t.Run(tool.Tool.Name, func(t *testing.T) {
628+
t.Run(tool.Name, func(t *testing.T) {
629629
t.Parallel()
630630

631631
// Check that Properties is not nil
632-
require.NotNil(t, tool.Tool.Schema.Properties,
633-
"Tool %q missing Schema.Properties", tool.Tool.Name)
632+
require.NotNil(t, tool.Schema.Properties,
633+
"Tool %q missing Schema.Properties", tool.Name)
634634

635635
// Check that Required is not nil
636-
require.NotNil(t, tool.Tool.Schema.Required,
637-
"Tool %q missing Schema.Required", tool.Tool.Name)
636+
require.NotNil(t, tool.Schema.Required,
637+
"Tool %q missing Schema.Required", tool.Name)
638638

639639
// Ensure Properties has entries for all required fields
640-
for _, requiredField := range tool.Tool.Schema.Required {
641-
_, exists := tool.Tool.Schema.Properties[requiredField]
640+
for _, requiredField := range tool.Schema.Required {
641+
_, exists := tool.Schema.Properties[requiredField]
642642
require.True(t, exists,
643643
"Tool %q requires field %q but it is not defined in Properties",
644-
tool.Tool.Name, requiredField)
644+
tool.Name, requiredField)
645645
}
646646
})
647647
}
@@ -652,16 +652,16 @@ func TestToolSchemaFields(t *testing.T) {
652652
func TestMain(m *testing.M) {
653653
// Initialize testedTools
654654
for _, tool := range toolsdk.All {
655-
testedTools.Store(tool.Tool.Name, false)
655+
testedTools.Store(tool.Name, false)
656656
}
657657

658658
code := m.Run()
659659

660660
// Ensure all tools have been tested
661661
var untested []string
662662
for _, tool := range toolsdk.All {
663-
if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) {
664-
untested = append(untested, tool.Tool.Name)
663+
if tested, ok := testedTools.Load(tool.Name); !ok || !tested.(bool) {
664+
untested = append(untested, tool.Name)
665665
}
666666
}
667667

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy