Skip to content

Commit 50998fc

Browse files
authored
Merge branch 'main' into jaaydenh/workspace-creation-fix
2 parents 49951a6 + b61f0ab commit 50998fc

File tree

5 files changed

+191
-57
lines changed

5 files changed

+191
-57
lines changed

agent/agent.go

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1773,15 +1773,22 @@ func (a *agent) Close() error {
17731773
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)
17741774

17751775
// Attempt to gracefully shut down all active SSH connections and
1776-
// stop accepting new ones.
1777-
err := a.sshServer.Shutdown(a.hardCtx)
1776+
// stop accepting new ones. If all processes have not exited after 5
1777+
// seconds, we just log it and move on as it's more important to run
1778+
// the shutdown scripts. A typical shutdown time for containers is
1779+
// 10 seconds, so this still leaves a bit of time to run the
1780+
// shutdown scripts in the worst-case.
1781+
sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 5*time.Second)
1782+
defer sshShutdownCancel()
1783+
err := a.sshServer.Shutdown(sshShutdownCtx)
17781784
if err != nil {
1779-
a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
1780-
}
1781-
err = a.sshServer.Close()
1782-
if err != nil {
1783-
a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
1785+
if errors.Is(err, context.DeadlineExceeded) {
1786+
a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err))
1787+
} else {
1788+
a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err))
1789+
}
17841790
}
1791+
17851792
// wait for SSH to shut down before the general graceful cancel, because
17861793
// this triggers a disconnect in the tailnet layer, telling all clients to
17871794
// shut down their wireguard tunnels to us. If SSH sessions are still up,

agent/agentssh/agentssh.go

Lines changed: 53 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
582582
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
583583
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
584584

585+
// Create a process group and send SIGHUP to child processes,
586+
// otherwise context cancellation will not propagate properly
587+
// and SSH server close may be delayed.
588+
cmd.SysProcAttr = cmdSysProcAttr()
589+
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)
590+
585591
cmd.Stdout = session
586592
cmd.Stderr = session.Stderr()
587593
// This blocks forever until stdin is received if we don't
@@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
926932
// Serve starts the server to handle incoming connections on the provided listener.
927933
// It returns an error if no host keys are set or if there is an issue accepting connections.
928934
func (s *Server) Serve(l net.Listener) (retErr error) {
929-
if len(s.srv.HostSigners) == 0 {
935+
// Ensure we're not mutating HostSigners as we're reading it.
936+
s.mu.RLock()
937+
noHostKeys := len(s.srv.HostSigners) == 0
938+
s.mu.RUnlock()
939+
940+
if noHostKeys {
930941
return xerrors.New("no host keys set")
931942
}
932943

@@ -1054,43 +1065,72 @@ func (s *Server) Close() error {
10541065
}
10551066
s.closing = make(chan struct{})
10561067

1068+
ctx := context.Background()
1069+
1070+
s.logger.Debug(ctx, "closing server")
1071+
1072+
// Stop accepting new connections.
1073+
s.logger.Debug(ctx, "closing all active listeners", slog.F("count", len(s.listeners)))
1074+
for l := range s.listeners {
1075+
_ = l.Close()
1076+
}
1077+
10571078
// Close all active sessions to gracefully
10581079
// terminate client connections.
1080+
s.logger.Debug(ctx, "closing all active sessions", slog.F("count", len(s.sessions)))
10591081
for ss := range s.sessions {
10601082
// We call Close on the underlying channel here because we don't
10611083
// want to send an exit status to the client (via Exit()).
10621084
// Typically OpenSSH clients will return 255 as the exit status.
10631085
_ = ss.Close()
10641086
}
1065-
1066-
// Close all active listeners and connections.
1067-
for l := range s.listeners {
1068-
_ = l.Close()
1069-
}
1087+
s.logger.Debug(ctx, "closing all active connections", slog.F("count", len(s.conns)))
10701088
for c := range s.conns {
10711089
_ = c.Close()
10721090
}
10731091

1074-
// Close the underlying SSH server.
1092+
s.logger.Debug(ctx, "closing SSH server")
10751093
err := s.srv.Close()
10761094

10771095
s.mu.Unlock()
1096+
1097+
s.logger.Debug(ctx, "waiting for all goroutines to exit")
10781098
s.wg.Wait() // Wait for all goroutines to exit.
10791099

10801100
s.mu.Lock()
10811101
close(s.closing)
10821102
s.closing = nil
10831103
s.mu.Unlock()
10841104

1105+
s.logger.Debug(ctx, "closing server done")
1106+
10851107
return err
10861108
}
10871109

1088-
// Shutdown gracefully closes all active SSH connections and stops
1089-
// accepting new connections.
1090-
//
1091-
// Shutdown is not implemented.
1092-
func (*Server) Shutdown(_ context.Context) error {
1093-
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
1110+
// Shutdown stops accepting new connections. The current implementation
1111+
// calls Close() for simplicity instead of waiting for existing
1112+
// connections to close. If the context times out, Shutdown will return
1113+
// but Close() may not have completed.
1114+
func (s *Server) Shutdown(ctx context.Context) error {
1115+
ch := make(chan error, 1)
1116+
go func() {
1117+
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
1118+
// For now we just close the server.
1119+
ch <- s.Close()
1120+
}()
1121+
var err error
1122+
select {
1123+
case <-ctx.Done():
1124+
err = ctx.Err()
1125+
case err = <-ch:
1126+
}
1127+
// Re-check for context cancellation precedence.
1128+
if ctx.Err() != nil {
1129+
err = ctx.Err()
1130+
}
1131+
if err != nil {
1132+
return xerrors.Errorf("close server: %w", err)
1133+
}
10941134
return nil
10951135
}
10961136

agent/agentssh/agentssh_test.go

Lines changed: 79 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ import (
2121
"go.uber.org/goleak"
2222
"golang.org/x/crypto/ssh"
2323

24+
"cdr.dev/slog"
2425
"cdr.dev/slog/sloggers/slogtest"
2526

2627
"github.com/coder/coder/v2/agent/agentexec"
@@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
147148
func TestNewServer_CloseActiveConnections(t *testing.T) {
148149
t.Parallel()
149150

150-
ctx := context.Background()
151-
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
152-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
153-
require.NoError(t, err)
154-
defer s.Close()
155-
err = s.UpdateHostSigner(42)
156-
assert.NoError(t, err)
151+
prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) {
152+
t.Helper()
153+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
154+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
155+
require.NoError(t, err)
156+
defer s.Close()
157+
err = s.UpdateHostSigner(42)
158+
assert.NoError(t, err)
157159

158-
ln, err := net.Listen("tcp", "127.0.0.1:0")
159-
require.NoError(t, err)
160+
ln, err := net.Listen("tcp", "127.0.0.1:0")
161+
require.NoError(t, err)
160162

161-
var wg sync.WaitGroup
162-
wg.Add(2)
163-
go func() {
164-
defer wg.Done()
165-
err := s.Serve(ln)
166-
assert.Error(t, err) // Server is closed.
167-
}()
163+
waitConns := make([]chan struct{}, 4)
168164

169-
pty := ptytest.New(t)
165+
var wg sync.WaitGroup
166+
wg.Add(1 + len(waitConns))
170167

171-
doClose := make(chan struct{})
172-
go func() {
173-
defer wg.Done()
174-
c := sshClient(t, ln.Addr().String())
175-
sess, err := c.NewSession()
176-
assert.NoError(t, err)
177-
sess.Stdin = pty.Input()
178-
sess.Stdout = pty.Output()
179-
sess.Stderr = pty.Output()
168+
go func() {
169+
defer wg.Done()
170+
err := s.Serve(ln)
171+
assert.Error(t, err) // Server is closed.
172+
}()
180173

181-
assert.NoError(t, err)
182-
err = sess.Start("")
183-
assert.NoError(t, err)
174+
for i := 0; i < len(waitConns); i++ {
175+
waitConns[i] = make(chan struct{})
176+
go func(ch chan struct{}) {
177+
defer wg.Done()
178+
c := sshClient(t, ln.Addr().String())
179+
sess, err := c.NewSession()
180+
assert.NoError(t, err)
181+
pty := ptytest.New(t)
182+
sess.Stdin = pty.Input()
183+
sess.Stdout = pty.Output()
184+
sess.Stderr = pty.Output()
185+
186+
// Every other session will request a PTY.
187+
if i%2 == 0 {
188+
err = sess.RequestPty("xterm", 80, 80, nil)
189+
assert.NoError(t, err)
190+
}
191+
// The 60 seconds here is intended to be longer than the
192+
// test. The shutdown should propagate.
193+
err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'")
194+
assert.NoError(t, err)
195+
196+
close(ch)
197+
err = sess.Wait()
198+
assert.Error(t, err)
199+
}(waitConns[i])
200+
}
184201

185-
close(doClose)
186-
err = sess.Wait()
187-
assert.Error(t, err)
188-
}()
202+
for _, ch := range waitConns {
203+
<-ch
204+
}
189205

190-
<-doClose
191-
err = s.Close()
192-
require.NoError(t, err)
206+
return s, wg.Wait
207+
}
208+
209+
t.Run("Close", func(t *testing.T) {
210+
t.Parallel()
211+
ctx := testutil.Context(t, testutil.WaitMedium)
212+
s, wait := prepare(ctx, t)
213+
err := s.Close()
214+
require.NoError(t, err)
215+
wait()
216+
})
193217

194-
wg.Wait()
218+
t.Run("Shutdown", func(t *testing.T) {
219+
t.Parallel()
220+
ctx := testutil.Context(t, testutil.WaitMedium)
221+
s, wait := prepare(ctx, t)
222+
err := s.Shutdown(ctx)
223+
require.NoError(t, err)
224+
wait()
225+
})
226+
227+
t.Run("Shutdown Early", func(t *testing.T) {
228+
t.Parallel()
229+
ctx := testutil.Context(t, testutil.WaitMedium)
230+
s, wait := prepare(ctx, t)
231+
ctx, cancel := context.WithCancel(ctx)
232+
cancel()
233+
err := s.Shutdown(ctx)
234+
require.ErrorIs(t, err, context.Canceled)
235+
wait()
236+
})
195237
}
196238

197239
func TestNewServer_Signal(t *testing.T) {

agent/agentssh/exec_other.go

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
//go:build !windows
2+
3+
package agentssh
4+
5+
import (
6+
"context"
7+
"os/exec"
8+
"syscall"
9+
10+
"cdr.dev/slog"
11+
)
12+
13+
func cmdSysProcAttr() *syscall.SysProcAttr {
14+
return &syscall.SysProcAttr{
15+
Setsid: true,
16+
}
17+
}
18+
19+
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
20+
return func() error {
21+
logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
22+
return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
23+
}
24+
}

agent/agentssh/exec_windows.go

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,21 @@
1+
package agentssh
2+
3+
import (
4+
"context"
5+
"os"
6+
"os/exec"
7+
"syscall"
8+
9+
"cdr.dev/slog"
10+
)
11+
12+
func cmdSysProcAttr() *syscall.SysProcAttr {
13+
return &syscall.SysProcAttr{}
14+
}
15+
16+
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
17+
return func() error {
18+
logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid))
19+
return cmd.Process.Signal(os.Interrupt)
20+
}
21+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy