Skip to content

Commit be118e6

Browse files
committed
reduce scope
1 parent 578ebb0 commit be118e6

File tree

6 files changed

+66
-181
lines changed

6 files changed

+66
-181
lines changed

cli/cliutil/stdioconn.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,31 @@ import (
66
"time"
77
)
88

9-
type StdioConn struct {
9+
type ReaderWriterConn struct {
1010
io.Reader
1111
io.Writer
1212
}
1313

14-
func (*StdioConn) Close() (err error) {
14+
func (*ReaderWriterConn) Close() (err error) {
1515
return nil
1616
}
1717

18-
func (*StdioConn) LocalAddr() net.Addr {
18+
func (*ReaderWriterConn) LocalAddr() net.Addr {
1919
return nil
2020
}
2121

22-
func (*StdioConn) RemoteAddr() net.Addr {
22+
func (*ReaderWriterConn) RemoteAddr() net.Addr {
2323
return nil
2424
}
2525

26-
func (*StdioConn) SetDeadline(_ time.Time) error {
26+
func (*ReaderWriterConn) SetDeadline(_ time.Time) error {
2727
return nil
2828
}
2929

30-
func (*StdioConn) SetReadDeadline(_ time.Time) error {
30+
func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error {
3131
return nil
3232
}
3333

34-
func (*StdioConn) SetWriteDeadline(_ time.Time) error {
34+
func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error {
3535
return nil
3636
}

cli/ssh.go

Lines changed: 50 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command {
6767
stdio bool
6868
hostPrefix string
6969
hostnameSuffix string
70-
forceTunnel bool
70+
forceNewTunnel bool
7171
forwardAgent bool
7272
forwardGPG bool
7373
identityAgent string
@@ -278,27 +278,38 @@ func (r *RootCmd) ssh() *serpent.Command {
278278
return err
279279
}
280280

281-
// See if we can use the Coder Connect tunnel
282-
if !forceTunnel {
281+
// If we're in stdio mode, check to see if we can use Coder Connect.
282+
// We don't support Coder Connect over non-stdio coder ssh yet.
283+
if stdio && !forceNewTunnel {
283284
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
284285
if err != nil {
285286
return xerrors.Errorf("get agent connection info: %w", err)
286287
}
287-
288288
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
289289
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
290290
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost)
291291
if exists {
292292
_, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...")
293293
defer cancel()
294-
addr := fmt.Sprintf("%s:22", coderConnectHost)
295-
if stdio {
294+
295+
if networkInfoDir != "" {
296296
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
297297
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
298298
}
299-
return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack)
300299
}
301-
return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack)
300+
301+
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
302+
defer stopPolling()
303+
304+
usageAppName := getUsageAppName(usageApp)
305+
if usageAppName != "" {
306+
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
307+
AgentID: workspaceAgent.ID,
308+
AppName: usageAppName,
309+
})
310+
defer closeUsage()
311+
}
312+
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
302313
}
303314
}
304315

@@ -481,11 +492,36 @@ func (r *RootCmd) ssh() *serpent.Command {
481492
stdinFile, validIn := inv.Stdin.(*os.File)
482493
stdoutFile, validOut := inv.Stdout.(*os.File)
483494
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
484-
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession)
485-
defer restorePtyFn()
495+
inState, err := pty.MakeInputRaw(stdinFile.Fd())
496+
if err != nil {
497+
return err
498+
}
499+
defer func() {
500+
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
501+
}()
502+
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
486503
if err != nil {
487-
return xerrors.Errorf("configure pty: %w", err)
504+
return err
488505
}
506+
defer func() {
507+
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
508+
}()
509+
510+
windowChange := listenWindowSize(ctx)
511+
go func() {
512+
for {
513+
select {
514+
case <-ctx.Done():
515+
return
516+
case <-windowChange:
517+
}
518+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
519+
if err != nil {
520+
continue
521+
}
522+
_ = sshSession.WindowChange(height, width)
523+
}
524+
}()
489525
}
490526

491527
for _, kv := range parsedEnv {
@@ -667,48 +703,14 @@ func (r *RootCmd) ssh() *serpent.Command {
667703
{
668704
Flag: "force-new-tunnel",
669705
Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.",
670-
Value: serpent.BoolOf(&forceTunnel),
706+
Value: serpent.BoolOf(&forceNewTunnel),
707+
Hidden: true,
671708
},
672709
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
673710
}
674711
return cmd
675712
}
676713

677-
func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) {
678-
inState, err := pty.MakeInputRaw(stdinFile.Fd())
679-
if err != nil {
680-
return restoreFn, err
681-
}
682-
restoreFn = func() {
683-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
684-
}
685-
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
686-
if err != nil {
687-
return restoreFn, err
688-
}
689-
restoreFn = func() {
690-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
691-
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
692-
}
693-
694-
windowChange := listenWindowSize(ctx)
695-
go func() {
696-
for {
697-
select {
698-
case <-ctx.Done():
699-
return
700-
case <-windowChange:
701-
}
702-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
703-
if err != nil {
704-
continue
705-
}
706-
_ = sshSession.WindowChange(height, width)
707-
}
708-
}()
709-
return restoreFn, nil
710-
}
711-
712714
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
713715
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
714716
// vscode-coder--myusername--myworkspace).
@@ -1502,87 +1504,14 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std
15021504
return err
15031505
}
15041506

1505-
agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{
1507+
agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{
15061508
Reader: stdin,
15071509
Writer: stdout,
15081510
})
15091511

15101512
return nil
15111513
}
15121514

1513-
func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error {
1514-
client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{
1515-
// We've already checked the agent's address
1516-
// is within the Coder service prefix.
1517-
// #nosec
1518-
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
1519-
})
1520-
if err != nil {
1521-
return xerrors.Errorf("dial coder connect host: %w", err)
1522-
}
1523-
if err := stack.push("ssh client", client); err != nil {
1524-
return err
1525-
}
1526-
1527-
session, err := client.NewSession()
1528-
if err != nil {
1529-
return xerrors.Errorf("create ssh session: %w", err)
1530-
}
1531-
if err := stack.push("ssh session", session); err != nil {
1532-
return err
1533-
}
1534-
1535-
stdinFile, validIn := stdin.(*os.File)
1536-
stdoutFile, validOut := stdout.(*os.File)
1537-
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
1538-
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session)
1539-
defer restorePtyFn()
1540-
if err != nil {
1541-
return xerrors.Errorf("configure pty: %w", err)
1542-
}
1543-
}
1544-
1545-
session.Stdin = stdin
1546-
session.Stdout = stdout
1547-
session.Stderr = stderr
1548-
1549-
err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{})
1550-
if err != nil {
1551-
return xerrors.Errorf("request pty: %w", err)
1552-
}
1553-
1554-
err = session.Shell()
1555-
if err != nil {
1556-
return xerrors.Errorf("start shell: %w", err)
1557-
}
1558-
1559-
if validOut {
1560-
// Set initial window size.
1561-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
1562-
if err == nil {
1563-
_ = session.WindowChange(height, width)
1564-
}
1565-
}
1566-
1567-
err = session.Wait()
1568-
if err != nil {
1569-
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
1570-
// Clear the error since it's not useful beyond
1571-
// reporting status.
1572-
return ExitError(exitErr.ExitStatus(), nil)
1573-
}
1574-
// If the connection drops unexpectedly, we get an
1575-
// ExitMissingError but no other error details, so try to at
1576-
// least give the user a better message
1577-
if errors.Is(err, &gossh.ExitMissingError{}) {
1578-
return ExitError(255, xerrors.New("SSH connection ended unexpectedly"))
1579-
}
1580-
return xerrors.Errorf("session ended: %w", err)
1581-
}
1582-
1583-
return nil
1584-
}
1585-
15861515
func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error {
15871516
fs, ok := ctx.Value("fs").(afero.Fs)
15881517
if !ok {

cli/ssh_internal_test.go

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222

2323
"github.com/coder/coder/v2/cli/cliutil"
2424
"github.com/coder/coder/v2/codersdk"
25-
"github.com/coder/coder/v2/pty/ptytest"
2625
"github.com/coder/coder/v2/testutil"
2726
)
2827

@@ -226,37 +225,6 @@ func TestCloserStack_Timeout(t *testing.T) {
226225
testutil.TryReceive(ctx, t, closed)
227226
}
228227

229-
func TestCoderConnectPTY(t *testing.T) {
230-
t.Parallel()
231-
232-
ctx := testutil.Context(t, testutil.WaitShort)
233-
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
234-
stack := newCloserStack(ctx, logger, quartz.NewMock(t))
235-
236-
server := newSSHServer("127.0.0.1:0")
237-
ln, err := net.Listen("tcp", server.server.Addr)
238-
require.NoError(t, err)
239-
240-
go func() {
241-
_ = server.Serve(ln)
242-
}()
243-
t.Cleanup(func() {
244-
_ = server.Close()
245-
})
246-
247-
ptty := ptytest.New(t)
248-
ptyDone := make(chan struct{})
249-
go func() {
250-
err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack)
251-
assert.NoError(t, err)
252-
close(ptyDone)
253-
}()
254-
ptty.ExpectMatch("Connected!")
255-
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
256-
ptty.WriteLine("exit")
257-
<-ptyDone
258-
}
259-
260228
func TestCoderConnectStdio(t *testing.T) {
261229
t.Parallel()
262230

@@ -290,7 +258,7 @@ func TestCoderConnectStdio(t *testing.T) {
290258
close(stdioDone)
291259
}()
292260

293-
conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{
261+
conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{
294262
Reader: serverOutput,
295263
Writer: clientInput,
296264
}, "", &ssh.ClientConfig{

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy