Skip to content

Commit 43b0bb7

Browse files
feat(site): use websocket connection for devcontainer updates (#18808)
Instead of polling every 10 seconds, we instead use a WebSocket connection for more timely updates.
1 parent 7cf3263 commit 43b0bb7

File tree

15 files changed

+1079
-23
lines changed

15 files changed

+1079
-23
lines changed

agent/agentcontainers/api.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package agentcontainers
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
8+
"maps"
79
"net/http"
810
"os"
911
"path"
@@ -30,6 +32,7 @@ import (
3032
"github.com/coder/coder/v2/codersdk/agentsdk"
3133
"github.com/coder/coder/v2/provisioner"
3234
"github.com/coder/quartz"
35+
"github.com/coder/websocket"
3336
)
3437

3538
const (
@@ -74,6 +77,7 @@ type API struct {
7477

7578
mu sync.RWMutex // Protects the following fields.
7679
initDone chan struct{} // Closed by Init.
80+
updateChans []chan struct{}
7781
closed bool
7882
containers codersdk.WorkspaceAgentListContainersResponse // Output from the last list operation.
7983
containersErr error // Error from the last list operation.
@@ -535,6 +539,7 @@ func (api *API) Routes() http.Handler {
535539
r.Use(ensureInitDoneMW)
536540

537541
r.Get("/", api.handleList)
542+
r.Get("/watch", api.watchContainers)
538543
// TODO(mafredri): Simplify this route as the previous /devcontainers
539544
// /-route was dropped. We can drop the /devcontainers prefix here too.
540545
r.Route("/devcontainers/{devcontainer}", func(r chi.Router) {
@@ -544,6 +549,88 @@ func (api *API) Routes() http.Handler {
544549
return r
545550
}
546551

552+
func (api *API) broadcastUpdatesLocked() {
553+
// Broadcast state changes to WebSocket listeners.
554+
for _, ch := range api.updateChans {
555+
select {
556+
case ch <- struct{}{}:
557+
default:
558+
}
559+
}
560+
}
561+
562+
func (api *API) watchContainers(rw http.ResponseWriter, r *http.Request) {
563+
ctx := r.Context()
564+
565+
conn, err := websocket.Accept(rw, r, nil)
566+
if err != nil {
567+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
568+
Message: "Failed to upgrade connection to websocket.",
569+
Detail: err.Error(),
570+
})
571+
return
572+
}
573+
574+
// Here we close the websocket for reading, so that the websocket library will handle pings and
575+
// close frames.
576+
_ = conn.CloseRead(context.Background())
577+
578+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
579+
defer wsNetConn.Close()
580+
581+
go httpapi.Heartbeat(ctx, conn)
582+
583+
updateCh := make(chan struct{}, 1)
584+
585+
api.mu.Lock()
586+
api.updateChans = append(api.updateChans, updateCh)
587+
api.mu.Unlock()
588+
589+
defer func() {
590+
api.mu.Lock()
591+
api.updateChans = slices.DeleteFunc(api.updateChans, func(ch chan struct{}) bool {
592+
return ch == updateCh
593+
})
594+
close(updateCh)
595+
api.mu.Unlock()
596+
}()
597+
598+
encoder := json.NewEncoder(wsNetConn)
599+
600+
ct, err := api.getContainers()
601+
if err != nil {
602+
api.logger.Error(ctx, "unable to get containers", slog.Error(err))
603+
return
604+
}
605+
606+
if err := encoder.Encode(ct); err != nil {
607+
api.logger.Error(ctx, "encode container list", slog.Error(err))
608+
return
609+
}
610+
611+
for {
612+
select {
613+
case <-api.ctx.Done():
614+
return
615+
616+
case <-ctx.Done():
617+
return
618+
619+
case <-updateCh:
620+
ct, err := api.getContainers()
621+
if err != nil {
622+
api.logger.Error(ctx, "unable to get containers", slog.Error(err))
623+
continue
624+
}
625+
626+
if err := encoder.Encode(ct); err != nil {
627+
api.logger.Error(ctx, "encode container list", slog.Error(err))
628+
return
629+
}
630+
}
631+
}
632+
}
633+
547634
// handleList handles the HTTP request to list containers.
548635
func (api *API) handleList(rw http.ResponseWriter, r *http.Request) {
549636
ct, err := api.getContainers()
@@ -583,8 +670,26 @@ func (api *API) updateContainers(ctx context.Context) error {
583670
api.mu.Lock()
584671
defer api.mu.Unlock()
585672

673+
var previouslyKnownDevcontainers map[string]codersdk.WorkspaceAgentDevcontainer
674+
if len(api.updateChans) > 0 {
675+
previouslyKnownDevcontainers = maps.Clone(api.knownDevcontainers)
676+
}
677+
586678
api.processUpdatedContainersLocked(ctx, updated)
587679

680+
if len(api.updateChans) > 0 {
681+
statesAreEqual := maps.EqualFunc(
682+
previouslyKnownDevcontainers,
683+
api.knownDevcontainers,
684+
func(dc1, dc2 codersdk.WorkspaceAgentDevcontainer) bool {
685+
return dc1.Equals(dc2)
686+
})
687+
688+
if !statesAreEqual {
689+
api.broadcastUpdatesLocked()
690+
}
691+
}
692+
588693
api.logger.Debug(ctx, "containers updated successfully", slog.F("container_count", len(api.containers.Containers)), slog.F("warning_count", len(api.containers.Warnings)), slog.F("devcontainer_count", len(api.knownDevcontainers)))
589694

590695
return nil
@@ -955,6 +1060,8 @@ func (api *API) handleDevcontainerRecreate(w http.ResponseWriter, r *http.Reques
9551060
dc.Container = nil
9561061
dc.Error = ""
9571062
api.knownDevcontainers[dc.WorkspaceFolder] = dc
1063+
api.broadcastUpdatesLocked()
1064+
9581065
go func() {
9591066
_ = api.CreateDevcontainer(dc.WorkspaceFolder, dc.ConfigPath, WithRemoveExistingContainer())
9601067
}()
@@ -1070,6 +1177,7 @@ func (api *API) CreateDevcontainer(workspaceFolder, configPath string, opts ...D
10701177
dc.Error = ""
10711178
api.recreateSuccessTimes[dc.WorkspaceFolder] = api.clock.Now("agentcontainers", "recreate", "successTimes")
10721179
api.knownDevcontainers[dc.WorkspaceFolder] = dc
1180+
api.broadcastUpdatesLocked()
10731181
api.mu.Unlock()
10741182

10751183
// Ensure an immediate refresh to accurately reflect the

agent/agentcontainers/api_test.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"github.com/coder/coder/v2/pty"
3737
"github.com/coder/coder/v2/testutil"
3838
"github.com/coder/quartz"
39+
"github.com/coder/websocket"
3940
)
4041

4142
// fakeContainerCLI implements the agentcontainers.ContainerCLI interface for
@@ -441,6 +442,178 @@ func TestAPI(t *testing.T) {
441442
logbuf.Reset()
442443
})
443444

445+
t.Run("Watch", func(t *testing.T) {
446+
t.Parallel()
447+
448+
fakeContainer1 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) {
449+
c.ID = "container1"
450+
c.FriendlyName = "devcontainer1"
451+
c.Image = "busybox:latest"
452+
c.Labels = map[string]string{
453+
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project1",
454+
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project1/.devcontainer/devcontainer.json",
455+
}
456+
})
457+
458+
fakeContainer2 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) {
459+
c.ID = "container2"
460+
c.FriendlyName = "devcontainer2"
461+
c.Image = "ubuntu:latest"
462+
c.Labels = map[string]string{
463+
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project2",
464+
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project2/.devcontainer/devcontainer.json",
465+
}
466+
})
467+
468+
stages := []struct {
469+
containers []codersdk.WorkspaceAgentContainer
470+
expected codersdk.WorkspaceAgentListContainersResponse
471+
}{
472+
{
473+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
474+
expected: codersdk.WorkspaceAgentListContainersResponse{
475+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
476+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
477+
{
478+
Name: "project1",
479+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
480+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
481+
Status: "running",
482+
Container: &fakeContainer1,
483+
},
484+
},
485+
},
486+
},
487+
{
488+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
489+
expected: codersdk.WorkspaceAgentListContainersResponse{
490+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
491+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
492+
{
493+
Name: "project1",
494+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
495+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
496+
Status: "running",
497+
Container: &fakeContainer1,
498+
},
499+
{
500+
Name: "project2",
501+
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
502+
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
503+
Status: "running",
504+
Container: &fakeContainer2,
505+
},
506+
},
507+
},
508+
},
509+
{
510+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
511+
expected: codersdk.WorkspaceAgentListContainersResponse{
512+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
513+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
514+
{
515+
Name: "",
516+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
517+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
518+
Status: "stopped",
519+
Container: nil,
520+
},
521+
{
522+
Name: "project2",
523+
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
524+
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
525+
Status: "running",
526+
Container: &fakeContainer2,
527+
},
528+
},
529+
},
530+
},
531+
}
532+
533+
var (
534+
ctx = testutil.Context(t, testutil.WaitShort)
535+
mClock = quartz.NewMock(t)
536+
updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop")
537+
mCtrl = gomock.NewController(t)
538+
mLister = acmock.NewMockContainerCLI(mCtrl)
539+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
540+
)
541+
542+
// Set up initial state for immediate send on connection
543+
mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stages[0].containers}, nil)
544+
mLister.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("<none>", nil).AnyTimes()
545+
546+
api := agentcontainers.NewAPI(logger,
547+
agentcontainers.WithClock(mClock),
548+
agentcontainers.WithContainerCLI(mLister),
549+
agentcontainers.WithWatcher(watcher.NewNoop()),
550+
)
551+
api.Start()
552+
defer api.Close()
553+
554+
srv := httptest.NewServer(api.Routes())
555+
defer srv.Close()
556+
557+
updaterTickerTrap.MustWait(ctx).MustRelease(ctx)
558+
defer updaterTickerTrap.Close()
559+
560+
client, res, err := websocket.Dial(ctx, srv.URL+"/watch", nil)
561+
require.NoError(t, err)
562+
if res != nil && res.Body != nil {
563+
defer res.Body.Close()
564+
}
565+
566+
// Read initial state sent immediately on connection
567+
mt, msg, err := client.Read(ctx)
568+
require.NoError(t, err)
569+
require.Equal(t, websocket.MessageText, mt)
570+
571+
var got codersdk.WorkspaceAgentListContainersResponse
572+
err = json.Unmarshal(msg, &got)
573+
require.NoError(t, err)
574+
575+
require.Equal(t, stages[0].expected.Containers, got.Containers)
576+
require.Len(t, got.Devcontainers, len(stages[0].expected.Devcontainers))
577+
for j, expectedDev := range stages[0].expected.Devcontainers {
578+
gotDev := got.Devcontainers[j]
579+
require.Equal(t, expectedDev.Name, gotDev.Name)
580+
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
581+
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
582+
require.Equal(t, expectedDev.Status, gotDev.Status)
583+
require.Equal(t, expectedDev.Container, gotDev.Container)
584+
}
585+
586+
// Process remaining stages through updater loop
587+
for i, stage := range stages[1:] {
588+
mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stage.containers}, nil)
589+
590+
// Given: We allow the update loop to progress
591+
_, aw := mClock.AdvanceNext()
592+
aw.MustWait(ctx)
593+
594+
// When: We attempt to read a message from the socket.
595+
mt, msg, err := client.Read(ctx)
596+
require.NoError(t, err)
597+
require.Equal(t, websocket.MessageText, mt)
598+
599+
// Then: We expect the receieved message matches the expected response.
600+
var got codersdk.WorkspaceAgentListContainersResponse
601+
err = json.Unmarshal(msg, &got)
602+
require.NoError(t, err)
603+
604+
require.Equal(t, stages[i+1].expected.Containers, got.Containers)
605+
require.Len(t, got.Devcontainers, len(stages[i+1].expected.Devcontainers))
606+
for j, expectedDev := range stages[i+1].expected.Devcontainers {
607+
gotDev := got.Devcontainers[j]
608+
require.Equal(t, expectedDev.Name, gotDev.Name)
609+
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
610+
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
611+
require.Equal(t, expectedDev.Status, gotDev.Status)
612+
require.Equal(t, expectedDev.Container, gotDev.Container)
613+
}
614+
}
615+
})
616+
444617
// List tests the API.getContainers method using a mock
445618
// implementation. It specifically tests caching behavior.
446619
t.Run("List", func(t *testing.T) {

coderd/apidoc/docs.go

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy