Skip to content

Commit b0ec201

Browse files
authored
Merge pull request #427 from alixander/fix-race
fix closenow race
2 parents 8d2374e + 250db1e commit b0ec201

File tree

10 files changed

+226
-162
lines changed

10 files changed

+226
-162
lines changed

accept_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"net/http/httptest"
1212
"strings"
13+
"sync"
1314
"testing"
1415

1516
"nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142143
_, err := Accept(w, r, nil)
143144
assert.Contains(t, err, `failed to hijack connection`)
144145
})
146+
t.Run("closeRace", func(t *testing.T) {
147+
t.Parallel()
148+
149+
server, _ := net.Pipe()
150+
151+
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
152+
newResponseWriter := func() http.ResponseWriter {
153+
return mockHijacker{
154+
ResponseWriter: httptest.NewRecorder(),
155+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
156+
return server, rw, nil
157+
},
158+
}
159+
}
160+
w := newResponseWriter()
161+
162+
r := httptest.NewRequest("GET", "/", nil)
163+
r.Header.Set("Connection", "Upgrade")
164+
r.Header.Set("Upgrade", "websocket")
165+
r.Header.Set("Sec-WebSocket-Version", "13")
166+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
167+
168+
c, err := Accept(w, r, nil)
169+
wg := &sync.WaitGroup{}
170+
wg.Add(2)
171+
go func() {
172+
c.Close(StatusInternalError, "the sky is falling")
173+
wg.Done()
174+
}()
175+
go func() {
176+
c.CloseNow()
177+
wg.Done()
178+
}()
179+
wg.Wait()
180+
assert.Success(t, err)
181+
})
145182
}
146183

147184
func Test_verifyClientHandshake(t *testing.T) {

close.go

Lines changed: 103 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode {
9797
//
9898
// Close will unblock all goroutines interacting with the connection once
9999
// complete.
100-
func (c *Conn) Close(code StatusCode, reason string) error {
101-
defer c.wg.Wait()
102-
return c.closeHandshake(code, reason)
100+
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101+
defer errd.Wrap(&err, "failed to close WebSocket")
102+
103+
if !c.casClosing() {
104+
err = c.waitGoroutines()
105+
if err != nil {
106+
return err
107+
}
108+
return net.ErrClosed
109+
}
110+
defer func() {
111+
if errors.Is(err, net.ErrClosed) {
112+
err = nil
113+
}
114+
}()
115+
116+
err = c.closeHandshake(code, reason)
117+
118+
err2 := c.close()
119+
if err == nil && err2 != nil {
120+
err = err2
121+
}
122+
123+
err2 = c.waitGoroutines()
124+
if err == nil && err2 != nil {
125+
err = err2
126+
}
127+
128+
return err
103129
}
104130

105131
// CloseNow closes the WebSocket connection without attempting a close handshake.
106132
// Use when you do not want the overhead of the close handshake.
107133
func (c *Conn) CloseNow() (err error) {
108-
defer c.wg.Wait()
109134
defer errd.Wrap(&err, "failed to close WebSocket")
110135

111-
if c.isClosed() {
136+
if !c.casClosing() {
137+
err = c.waitGoroutines()
138+
if err != nil {
139+
return err
140+
}
112141
return net.ErrClosed
113142
}
143+
defer func() {
144+
if errors.Is(err, net.ErrClosed) {
145+
err = nil
146+
}
147+
}()
114148

115-
c.close(nil)
116-
return c.closeErr
117-
}
118-
119-
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
120-
defer errd.Wrap(&err, "failed to close WebSocket")
121-
122-
writeErr := c.writeClose(code, reason)
123-
closeHandshakeErr := c.waitCloseHandshake()
149+
err = c.close()
124150

125-
if writeErr != nil {
126-
return writeErr
151+
err2 := c.waitGoroutines()
152+
if err == nil && err2 != nil {
153+
err = err2
127154
}
155+
return err
156+
}
128157

129-
if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
130-
return closeHandshakeErr
158+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
159+
err := c.writeClose(code, reason)
160+
if err != nil {
161+
return err
131162
}
132163

164+
err = c.waitCloseHandshake()
165+
if CloseStatus(err) != code {
166+
return err
167+
}
133168
return nil
134169
}
135170

136171
func (c *Conn) writeClose(code StatusCode, reason string) error {
137-
c.closeMu.Lock()
138-
wroteClose := c.wroteClose
139-
c.wroteClose = true
140-
c.closeMu.Unlock()
141-
if wroteClose {
142-
return net.ErrClosed
143-
}
144-
145172
ce := CloseError{
146173
Code: code,
147174
Reason: reason,
148175
}
149176

150177
var p []byte
151-
var marshalErr error
178+
var err error
152179
if ce.Code != StatusNoStatusRcvd {
153-
p, marshalErr = ce.bytes()
154-
}
155-
156-
writeErr := c.writeControl(context.Background(), opClose, p)
157-
if CloseStatus(writeErr) != -1 {
158-
// Not a real error if it's due to a close frame being received.
159-
writeErr = nil
180+
p, err = ce.bytes()
181+
if err != nil {
182+
return err
183+
}
160184
}
161185

162-
// We do this after in case there was an error writing the close frame.
163-
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
186+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
187+
defer cancel()
164188

165-
if marshalErr != nil {
166-
return marshalErr
189+
err = c.writeControl(ctx, opClose, p)
190+
// If the connection closed as we're writing we ignore the error as we might
191+
// have written the close frame, the peer responded and then someone else read it
192+
// and closed the connection.
193+
if err != nil && !errors.Is(err, net.ErrClosed) {
194+
return err
167195
}
168-
return writeErr
196+
return nil
169197
}
170198

171199
func (c *Conn) waitCloseHandshake() error {
172-
defer c.close(nil)
173-
174200
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
175201
defer cancel()
176202

@@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error {
180206
}
181207
defer c.readMu.unlock()
182208

183-
if c.readCloseFrameErr != nil {
184-
return c.readCloseFrameErr
185-
}
186-
187209
for i := int64(0); i < c.msgReader.payloadLength; i++ {
188210
_, err := c.br.ReadByte()
189211
if err != nil {
@@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
206228
}
207229
}
208230

231+
func (c *Conn) waitGoroutines() error {
232+
t := time.NewTimer(time.Second * 15)
233+
defer t.Stop()
234+
235+
select {
236+
case <-c.timeoutLoopDone:
237+
case <-t.C:
238+
return errors.New("failed to wait for timeoutLoop goroutine to exit")
239+
}
240+
241+
c.closeReadMu.Lock()
242+
closeRead := c.closeReadCtx != nil
243+
c.closeReadMu.Unlock()
244+
if closeRead {
245+
select {
246+
case <-c.closeReadDone:
247+
case <-t.C:
248+
return errors.New("failed to wait for close read goroutine to exit")
249+
}
250+
}
251+
252+
select {
253+
case <-c.closed:
254+
case <-t.C:
255+
return errors.New("failed to wait for connection to be closed")
256+
}
257+
258+
return nil
259+
}
260+
209261
func parseClosePayload(p []byte) (CloseError, error) {
210262
if len(p) == 0 {
211263
return CloseError{
@@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
276328
return buf, nil
277329
}
278330

279-
func (c *Conn) setCloseErr(err error) {
331+
func (c *Conn) casClosing() bool {
280332
c.closeMu.Lock()
281-
c.setCloseErrLocked(err)
282-
c.closeMu.Unlock()
283-
}
284-
285-
func (c *Conn) setCloseErrLocked(err error) {
286-
if c.closeErr == nil && err != nil {
287-
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
333+
defer c.closeMu.Unlock()
334+
if !c.closing {
335+
c.closing = true
336+
return true
288337
}
338+
return false
289339
}
290340

291341
func (c *Conn) isClosed() bool {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy