diff --git a/accept_test.go b/accept_test.go index 7cb85d0f..18233b1e 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "nhooyr.io/websocket/internal/test/assert" @@ -142,6 +143,42 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + t.Run("closeRace", func(t *testing.T) { + t.Parallel() + + server, _ := net.Pipe() + + rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)) + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return server, rw, nil + }, + } + } + w := newResponseWriter() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + c, err := Accept(w, r, nil) + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.Close(StatusInternalError, "the sky is falling") + wg.Done() + }() + go func() { + c.CloseNow() + wg.Done() + }() + wg.Wait() + assert.Success(t, err) + }) } func Test_verifyClientHandshake(t *testing.T) { diff --git a/close.go b/close.go index e925d043..625ed121 100644 --- a/close.go +++ b/close.go @@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode { // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() - return c.closeHandshake(code, reason) +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { - defer c.wg.Wait() defer errd.Wrap(&err, "failed to close WebSocket") - if c.isClosed() { + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } return net.ErrClosed } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() - c.close(nil) - return c.closeErr -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() + err = c.close() - if writeErr != nil { - return writeErr + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 } + return err +} - if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { - return closeHandshakeErr +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err } + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return net.ErrClosed - } - ce := CloseError{ Code: code, Reason: reason, } var p []byte - var marshalErr error + var err error if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil + p, err = ce.bytes() + if err != nil { + return err + } } - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - if marshalErr != nil { - return marshalErr + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err } - return writeErr + return nil } func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { @@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error { } } +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + closeRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closeRead { + select { + case <-c.closeReadDone: + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) setCloseErr(err error) { +func (c *Conn) casClosing() bool { c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil && err != nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + defer c.closeMu.Unlock() + if !c.closing { + c.closing = true + return true } + return false } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index ef4d62ad..8690fb3b 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" "context" - "errors" "fmt" "io" "net" @@ -53,15 +52,15 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader // Write state. msgWriter *msgWriter @@ -70,11 +69,13 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header - wg sync.WaitGroup - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closed chan struct{} + closeMu sync.Mutex + closing bool pingCounter int32 activePingsMu sync.Mutex @@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close() }) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.timeoutLoop() - }() + go c.timeoutLoop() return c } @@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { - return - } - if err == nil { - err = c.rwc.Close() + return net.ErrClosed } - c.setCloseErrLocked(err) - - close(c.closed) runtime.SetFinalizer(c, nil) + close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. - c.rwc.Close() - - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.msgWriter.close() - c.msgReader.close() - }() + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err } func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() @@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.writeError(StatusPolicyViolation, errors.New("read timed out")) - }() + c.close() + return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close() return } } @@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } @@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected diff --git a/conn_test.go b/conn_test.go index 97b172dc..9fbe961d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -345,6 +345,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.Skip() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ @@ -360,7 +363,7 @@ func TestWasm(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() diff --git a/mask.go b/mask.go index 5f0746dc..7bc0c8d5 100644 --- a/mask.go +++ b/mask.go @@ -16,8 +16,6 @@ import ( // to be in little endian. // // See https://github.com/golang/go/issues/31586 -// -//lint:ignore U1000 mask.go func maskGo(b []byte, key uint32) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) diff --git a/mask_asm.go b/mask_asm.go index 259eec03..f9484b5b 100644 --- a/mask_asm.go +++ b/mask_asm.go @@ -3,10 +3,14 @@ package websocket func mask(b []byte, key uint32) uint32 { - if len(b) > 0 { - return maskAsm(&b[0], len(b), key) - } - return key + // TODO: Will enable in v1.9.0. + return maskGo(b, key) + /* + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key + */ } // @nhooyr: I am not confident that the amd64 or the arm64 implementations of this @@ -18,4 +22,5 @@ func mask(b []byte, key uint32) uint32 { // See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 // //go:noescape +//lint:ignore U1000 disabled till v1.9.0 func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/netconn.go b/netconn.go index 3324014d..86f7dadb 100644 --- a/netconn.go +++ b/netconn.go @@ -94,7 +94,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. + // These must be first to be aligned on 32 bit platforms. // https://github.com/nhooyr/websocket/pull/438 readExpired int64 writeExpired int64 diff --git a/read.go b/read.go index 81b89831..a59e71d9 100644 --- a/read.go +++ b/read.go @@ -60,14 +60,24 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadDone = make(chan struct{}) + c.closeReadMu.Unlock() - c.wg.Add(1) go func() { - defer c.CloseNow() - defer c.wg.Done() + defer close(c.closeReadDone) defer cancel() + defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") @@ -222,7 +232,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.close(err) return header{}, err } } @@ -251,9 +260,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err + return n, fmt.Errorf("failed to read frame payload: %w", err) } } @@ -308,9 +315,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return nil } - defer func() { - c.readCloseFrameErr = err - }() + // opClose ce, err := parseClosePayload(b) if err != nil { @@ -320,9 +325,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.close(err) + c.readMu.unlock() + c.close() return err } @@ -336,9 +341,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -411,10 +414,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { diff --git a/write.go b/write.go index 7ac7ce63..d7222f2d 100644 --- a/write.go +++ b/write.go @@ -159,7 +159,6 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -242,30 +241,12 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } - - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - c.writeFrameMu.unlock() - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, net.ErrClosed - } - } defer c.writeFrameMu.unlock() select { @@ -283,7 +264,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco err = ctx.Err() default: } - c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() @@ -392,7 +372,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index 77d0d80f..02d61f28 100644 --- a/ws_js.go +++ b/ws_js.go @@ -47,9 +47,10 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit xsync.Int64 - wg sync.WaitGroup + closeReadMu sync.Mutex + closeReadCtx context.Context + closingMu sync.Mutex - isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - if c.isReadClosed.Load() == 1 { + c.closeReadMu.Lock() + closedRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } @@ -225,7 +229,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -239,7 +242,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { - defer c.wg.Wait() return c.Close(StatusGoingAway, "") } @@ -389,14 +391,19 @@ func (w *writer) Close() error { // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) - c.wg.Add(1) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() + go func() { - defer c.CloseNow() - defer c.wg.Done() defer cancel() + defer c.CloseNow() _, _, err := c.read(ctx) if err != nil { c.Close(StatusPolicyViolation, "unexpected data message") pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy