Skip to content

Commit 2acf0ad

Browse files
authored
chore(codersdk/toolsdk): improve static analyzability of toolsdk.Tools (#17562)
* Refactors toolsdk.Tools to remove opaque `map[string]any` argument in favour of typed args structs. * Refactors toolsdk.Tools to remove opaque passing of dependencies via `context.Context` in favour of a tool dependencies struct. * Adds panic recovery and clean context middleware to all tools. * Adds `GenericTool` implementation to allow keeping `toolsdk.All` with uniform type signature while maintaining type information in handlers. * Adds stricter checks to `patchWorkspaceAgentAppStatus` handler.
1 parent 1fc74f6 commit 2acf0ad

File tree

6 files changed

+1139
-865
lines changed

6 files changed

+1139
-865
lines changed

cli/exp_mcp.go

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cli
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -427,22 +428,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
427428
server.WithInstructions(instructions),
428429
)
429430

430-
// Create a new context for the tools with all relevant information.
431-
clientCtx := toolsdk.WithClient(ctx, client)
432431
// Get the workspace agent token from the environment.
432+
toolOpts := make([]func(*toolsdk.Deps), 0)
433433
var hasAgentClient bool
434434
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
435435
hasAgentClient = true
436436
agentClient := agentsdk.New(client.URL)
437437
agentClient.SetSessionToken(agentToken)
438-
clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient)
438+
toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient))
439439
} else {
440440
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
441441
}
442-
if appStatusSlug == "" {
443-
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
442+
443+
if appStatusSlug != "" {
444+
toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug))
444445
} else {
445-
clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug)
446+
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
447+
}
448+
449+
toolDeps, err := toolsdk.NewDeps(client, toolOpts...)
450+
if err != nil {
451+
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
446452
}
447453

448454
// Register tools based on the allowlist (if specified)
@@ -455,15 +461,15 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
455461
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
456462
return t == tool.Tool.Name
457463
}) {
458-
mcpSrv.AddTools(mcpFromSDK(tool))
464+
mcpSrv.AddTools(mcpFromSDK(tool, toolDeps))
459465
}
460466
}
461467

462468
srv := server.NewStdioServer(mcpSrv)
463469
done := make(chan error)
464470
go func() {
465471
defer close(done)
466-
srvErr := srv.Listen(clientCtx, invStdin, invStdout)
472+
srvErr := srv.Listen(ctx, invStdin, invStdout)
467473
done <- srvErr
468474
}()
469475

@@ -726,7 +732,7 @@ func getAgentToken(fs afero.Fs) (string, error) {
726732

727733
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
728734
// It assumes that the tool responds with a valid JSON object.
729-
func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
735+
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
730736
// NOTE: some clients will silently refuse to use tools if there is an issue
731737
// with the tool's schema or configuration.
732738
if sdkTool.Schema.Properties == nil {
@@ -743,27 +749,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
743749
},
744750
},
745751
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
746-
result, err := sdkTool.Handler(ctx, request.Params.Arguments)
752+
var buf bytes.Buffer
753+
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
754+
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
755+
}
756+
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
747757
if err != nil {
748758
return nil, err
749759
}
750-
var sb strings.Builder
751-
if err := json.NewEncoder(&sb).Encode(result); err == nil {
752-
return &mcp.CallToolResult{
753-
Content: []mcp.Content{
754-
mcp.NewTextContent(sb.String()),
755-
},
756-
}, nil
757-
}
758-
// If the result is not JSON, return it as a string.
759-
// This is a fallback for tools that return non-JSON data.
760-
resultStr, ok := result.(string)
761-
if !ok {
762-
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
763-
}
764760
return &mcp.CallToolResult{
765761
Content: []mcp.Content{
766-
mcp.NewTextContent(resultStr),
762+
mcp.NewTextContent(string(result)),
767763
},
768764
}, nil
769765
},

cli/exp_mcp_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ func TestExpMcpServer(t *testing.T) {
3131
t.Parallel()
3232

3333
ctx := testutil.Context(t, testutil.WaitShort)
34+
cmdDone := make(chan struct{})
3435
cancelCtx, cancel := context.WithCancel(ctx)
35-
t.Cleanup(cancel)
3636

3737
// Given: a running coder deployment
3838
client := coderdtest.New(t, nil)
39-
_ = coderdtest.CreateFirstUser(t, client)
39+
owner := coderdtest.CreateFirstUser(t, client)
4040

4141
// Given: we run the exp mcp command with allowed tools set
4242
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
@@ -48,7 +48,6 @@ func TestExpMcpServer(t *testing.T) {
4848
// nolint: gocritic // not the focus of this test
4949
clitest.SetupConfig(t, client, root)
5050

51-
cmdDone := make(chan struct{})
5251
go func() {
5352
defer close(cmdDone)
5453
err := inv.Run()
@@ -61,9 +60,6 @@ func TestExpMcpServer(t *testing.T) {
6160
_ = pty.ReadLine(ctx) // ignore echoed output
6261
output := pty.ReadLine(ctx)
6362

64-
cancel()
65-
<-cmdDone
66-
6763
// Then: we should only see the allowed tools in the response
6864
var toolsResponse struct {
6965
Result struct {
@@ -81,6 +77,20 @@ func TestExpMcpServer(t *testing.T) {
8177
}
8278
slices.Sort(foundTools)
8379
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
80+
81+
// Call the tool and ensure it works.
82+
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
83+
pty.WriteLine(toolPayload)
84+
_ = pty.ReadLine(ctx) // ignore echoed output
85+
output = pty.ReadLine(ctx)
86+
require.NotEmpty(t, output, "should have received a response from the tool")
87+
// Ensure it's valid JSON
88+
_, err = json.Marshal(output)
89+
require.NoError(t, err, "should have received a valid JSON response from the tool")
90+
// Ensure the tool returns the expected user
91+
require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID")
92+
cancel()
93+
<-cmdDone
8494
})
8595

8696
t.Run("OK", func(t *testing.T) {

coderd/workspaceagents.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,33 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
338338
Slug: req.AppSlug,
339339
})
340340
if err != nil {
341-
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
341+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
342342
Message: "Failed to get workspace app.",
343-
Detail: err.Error(),
343+
Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug),
344+
})
345+
return
346+
}
347+
348+
if len(req.Message) > 160 {
349+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
350+
Message: "Message is too long.",
351+
Detail: "Message must be less than 160 characters.",
352+
Validations: []codersdk.ValidationError{
353+
{Field: "message", Detail: "Message must be less than 160 characters."},
354+
},
355+
})
356+
return
357+
}
358+
359+
switch req.State {
360+
case codersdk.WorkspaceAppStatusStateComplete, codersdk.WorkspaceAppStatusStateFailure, codersdk.WorkspaceAppStatusStateWorking: // valid states
361+
default:
362+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
363+
Message: "Invalid state provided.",
364+
Detail: fmt.Sprintf("invalid state: %q", req.State),
365+
Validations: []codersdk.ValidationError{
366+
{Field: "state", Detail: "State must be one of: complete, failure, working."},
367+
},
344368
})
345369
return
346370
}

coderd/workspaceagents_test.go

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -340,27 +340,27 @@ func TestWorkspaceAgentLogs(t *testing.T) {
340340

341341
func TestWorkspaceAgentAppStatus(t *testing.T) {
342342
t.Parallel()
343-
t.Run("Success", func(t *testing.T) {
344-
t.Parallel()
345-
ctx := testutil.Context(t, testutil.WaitMedium)
346-
client, db := coderdtest.NewWithDatabase(t, nil)
347-
user := coderdtest.CreateFirstUser(t, client)
348-
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
343+
client, db := coderdtest.NewWithDatabase(t, nil)
344+
user := coderdtest.CreateFirstUser(t, client)
345+
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
349346

350-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
351-
OrganizationID: user.OrganizationID,
352-
OwnerID: user2.ID,
353-
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
354-
a[0].Apps = []*proto.App{
355-
{
356-
Slug: "vscode",
357-
},
358-
}
359-
return a
360-
}).Do()
347+
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
348+
OrganizationID: user.OrganizationID,
349+
OwnerID: user2.ID,
350+
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
351+
a[0].Apps = []*proto.App{
352+
{
353+
Slug: "vscode",
354+
},
355+
}
356+
return a
357+
}).Do()
361358

362-
agentClient := agentsdk.New(client.URL)
363-
agentClient.SetSessionToken(r.AgentToken)
359+
agentClient := agentsdk.New(client.URL)
360+
agentClient.SetSessionToken(r.AgentToken)
361+
t.Run("Success", func(t *testing.T) {
362+
t.Parallel()
363+
ctx := testutil.Context(t, testutil.WaitShort)
364364
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
365365
AppSlug: "vscode",
366366
Message: "testing",
@@ -381,6 +381,51 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
381381
require.Empty(t, agent.Apps[0].Statuses[0].Icon)
382382
require.False(t, agent.Apps[0].Statuses[0].NeedsUserAttention)
383383
})
384+
385+
t.Run("FailUnknownApp", func(t *testing.T) {
386+
t.Parallel()
387+
ctx := testutil.Context(t, testutil.WaitShort)
388+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
389+
AppSlug: "unknown",
390+
Message: "testing",
391+
URI: "https://example.com",
392+
State: codersdk.WorkspaceAppStatusStateComplete,
393+
})
394+
require.ErrorContains(t, err, "No app found with slug")
395+
var sdkErr *codersdk.Error
396+
require.ErrorAs(t, err, &sdkErr)
397+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
398+
})
399+
400+
t.Run("FailUnknownState", func(t *testing.T) {
401+
t.Parallel()
402+
ctx := testutil.Context(t, testutil.WaitShort)
403+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
404+
AppSlug: "vscode",
405+
Message: "testing",
406+
URI: "https://example.com",
407+
State: "unknown",
408+
})
409+
require.ErrorContains(t, err, "Invalid state")
410+
var sdkErr *codersdk.Error
411+
require.ErrorAs(t, err, &sdkErr)
412+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
413+
})
414+
415+
t.Run("FailTooLong", func(t *testing.T) {
416+
t.Parallel()
417+
ctx := testutil.Context(t, testutil.WaitShort)
418+
err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
419+
AppSlug: "vscode",
420+
Message: strings.Repeat("a", 161),
421+
URI: "https://example.com",
422+
State: codersdk.WorkspaceAppStatusStateComplete,
423+
})
424+
require.ErrorContains(t, err, "Message is too long")
425+
var sdkErr *codersdk.Error
426+
require.ErrorAs(t, err, &sdkErr)
427+
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
428+
})
384429
}
385430

386431
func TestWorkspaceAgentConnectRPC(t *testing.T) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy