Skip to content

Commit 398e80f

Browse files
authored
feat: add timeout support to workspace bash tool (#19035)
# Add timeout support to workspace bash tool This PR adds a timeout feature to the workspace bash tool, allowing users to specify a maximum execution time for commands. Key changes include: - Added a `timeout_ms` parameter to control command execution time (defaults to 60 seconds, with a maximum of 5 minutes) - Implemented a new `executeCommandWithTimeout` function that properly handles command timeouts - Added proper output capturing during timeout scenarios, returning all output collected before the timeout - Updated documentation to explain the timeout feature and provide usage examples - Added comprehensive tests for the timeout functionality, including integration tests When a command times out, the tool now returns all captured output up to that point along with a cancellation message, making it clear to users what happened. Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent d159578 commit 398e80f

File tree

2 files changed

+321
-11
lines changed

2 files changed

+321
-11
lines changed

codersdk/toolsdk/bash.go

Lines changed: 141 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package toolsdk
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
67
"fmt"
78
"io"
89
"strings"
10+
"sync"
11+
"time"
912

1013
gossh "golang.org/x/crypto/ssh"
1114
"golang.org/x/xerrors"
@@ -20,6 +23,7 @@ import (
2023
type WorkspaceBashArgs struct {
2124
Workspace string `json:"workspace"`
2225
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
2327
}
2428

2529
type WorkspaceBashResult struct {
@@ -43,9 +47,12 @@ The workspace parameter supports various formats:
4347
- workspace.agent (specific agent)
4448
- owner/workspace.agent
4549
50+
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
51+
If the command times out, all output captured up to that point is returned with a cancellation message.
52+
4653
Examples:
4754
- workspace: "my-workspace", command: "ls -la"
48-
- workspace: "john/dev-env", command: "git status"
55+
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
4956
- workspace: "my-workspace.main", command: "docker ps"`,
5057
Schema: aisdk.Schema{
5158
Properties: map[string]any{
@@ -57,18 +64,27 @@ Examples:
5764
"type": "string",
5865
"description": "The bash command to execute in the workspace.",
5966
},
67+
"timeout_ms": map[string]any{
68+
"type": "integer",
69+
"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
70+
"default": 60000,
71+
"minimum": 1,
72+
},
6073
},
6174
Required: []string{"workspace", "command"},
6275
},
6376
},
64-
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
77+
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
6578
if args.Workspace == "" {
6679
return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
6780
}
6881
if args.Command == "" {
6982
return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
7083
}
7184

85+
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
86+
defer cancel()
87+
7288
// Normalize workspace input to handle various formats
7389
workspaceName := NormalizeWorkspaceInput(args.Workspace)
7490

@@ -119,23 +135,42 @@ Examples:
119135
}
120136
defer session.Close()
121137

122-
// Execute command and capture output
123-
output, err := session.CombinedOutput(args.Command)
138+
// Set default timeout if not specified (60 seconds)
139+
timeoutMs := args.TimeoutMs
140+
if timeoutMs <= 0 {
141+
timeoutMs = 60000
142+
}
143+
144+
// Create context with timeout
145+
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146+
defer cancel()
147+
148+
// Execute command with timeout handling
149+
output, err := executeCommandWithTimeout(ctx, session, args.Command)
124150
outputStr := strings.TrimSpace(string(output))
125151

152+
// Handle command execution results
126153
if err != nil {
127-
// Check if it's an SSH exit error to get the exit code
128-
var exitErr *gossh.ExitError
129-
if errors.As(err, &exitErr) {
154+
// Check if the command timed out
155+
if errors.Is(context.Cause(ctx), context.DeadlineExceeded) {
156+
outputStr += "\nCommand canceled due to timeout"
130157
return WorkspaceBashResult{
131158
Output: outputStr,
132-
ExitCode: exitErr.ExitStatus(),
159+
ExitCode: 124,
133160
}, nil
134161
}
135-
// For other errors, return exit code 1
162+
163+
// Extract exit code from SSH error if available
164+
exitCode := 1
165+
var exitErr *gossh.ExitError
166+
if errors.As(err, &exitErr) {
167+
exitCode = exitErr.ExitStatus()
168+
}
169+
170+
// For other errors, use standard timeout or generic error code
136171
return WorkspaceBashResult{
137172
Output: outputStr,
138-
ExitCode: 1,
173+
ExitCode: exitCode,
139174
}, nil
140175
}
141176

@@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string {
292327

293328
return normalized
294329
}
330+
331+
// executeCommandWithTimeout executes a command with timeout support
332+
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
333+
// Set up pipes to capture output
334+
stdoutPipe, err := session.StdoutPipe()
335+
if err != nil {
336+
return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
337+
}
338+
339+
stderrPipe, err := session.StderrPipe()
340+
if err != nil {
341+
return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
342+
}
343+
344+
// Start the command
345+
if err := session.Start(command); err != nil {
346+
return nil, xerrors.Errorf("failed to start command: %w", err)
347+
}
348+
349+
// Create a thread-safe buffer for combined output
350+
var output bytes.Buffer
351+
var mu sync.Mutex
352+
safeWriter := &syncWriter{w: &output, mu: &mu}
353+
354+
// Use io.MultiWriter to combine stdout and stderr
355+
multiWriter := io.MultiWriter(safeWriter)
356+
357+
// Channel to signal when command completes
358+
done := make(chan error, 1)
359+
360+
// Start goroutine to copy output and wait for completion
361+
go func() {
362+
// Copy stdout and stderr concurrently
363+
var wg sync.WaitGroup
364+
wg.Add(2)
365+
366+
go func() {
367+
defer wg.Done()
368+
_, _ = io.Copy(multiWriter, stdoutPipe)
369+
}()
370+
371+
go func() {
372+
defer wg.Done()
373+
_, _ = io.Copy(multiWriter, stderrPipe)
374+
}()
375+
376+
// Wait for all output to be copied
377+
wg.Wait()
378+
379+
// Wait for the command to complete
380+
done <- session.Wait()
381+
}()
382+
383+
// Wait for either completion or context cancellation
384+
select {
385+
case err := <-done:
386+
// Command completed normally
387+
return safeWriter.Bytes(), err
388+
case <-ctx.Done():
389+
// Context was canceled (timeout or other cancellation)
390+
// Close the session to stop the command
391+
_ = session.Close()
392+
393+
// Give a brief moment to collect any remaining output
394+
timer := time.NewTimer(50 * time.Millisecond)
395+
defer timer.Stop()
396+
397+
select {
398+
case <-timer.C:
399+
// Timer expired, return what we have
400+
case err := <-done:
401+
// Command finished during grace period
402+
return safeWriter.Bytes(), err
403+
}
404+
405+
return safeWriter.Bytes(), context.Cause(ctx)
406+
}
407+
}
408+
409+
// syncWriter is a thread-safe writer
410+
type syncWriter struct {
411+
w *bytes.Buffer
412+
mu *sync.Mutex
413+
}
414+
415+
func (sw *syncWriter) Write(p []byte) (n int, err error) {
416+
sw.mu.Lock()
417+
defer sw.mu.Unlock()
418+
return sw.w.Write(p)
419+
}
420+
421+
func (sw *syncWriter) Bytes() []byte {
422+
sw.mu.Lock()
423+
defer sw.mu.Unlock()
424+
return sw.w.Bytes()
425+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy