diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index fb83c3a9..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: nhooyr diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..fb0a4558 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,24 @@ +version: 2 +updates: + # Track in case we ever add dependencies. + - package-ecosystem: 'gomod' + directory: '/' + schedule: + interval: 'weekly' + commit-message: + prefix: 'chore' + + # Keep example and test/benchmark deps up-to-date. + - package-ecosystem: 'gomod' + directories: + - '/internal/examples' + - '/internal/thirdparty' + schedule: + interval: 'monthly' + commit-message: + prefix: 'chore' + labels: [] + groups: + internal-deps: + patterns: + - '*' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 13ddbf3e..9f7aed46 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,11 @@ name: ci -on: [push, pull_request] +on: + push: + branches: + - master + pull_request: + branches: + - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true @@ -20,17 +26,25 @@ jobs: - uses: actions/checkout@v4 - run: go version - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod - run: ./ci/lint.sh test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index 2ba9ce34..0eac94cc 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -19,12 +19,18 @@ jobs: test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html @@ -41,6 +47,12 @@ jobs: test-dev: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 with: ref: dev @@ -48,7 +60,7 @@ jobs: with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: coverage-dev.html path: ./ci/out/coverage.html diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml new file mode 100644 index 00000000..6ea76ab6 --- /dev/null +++ b/.github/workflows/static.yml @@ -0,0 +1,52 @@ +name: static + +on: + push: + branches: ['master'] + workflow_dispatch: + +# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages. +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: pages + cancel-in-progress: true + +jobs: + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - name: Generate coverage and badge + run: | + ./ci/test.sh + mkdir -p ./ci/out/static + cp ./ci/out/coverage.html ./ci/out/static/coverage.html + percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%') + wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success" + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: ./ci/out/static/ + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/README.md b/README.md index c74b79dd..80d2b3cc 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # websocket [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket) -[![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://github.com/coder/websocket/coverage.html) +[![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. @@ -63,7 +63,9 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { } defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + // Set the context as needed. Use of r.Context() is not recommended + // to avoid surprising behavior (see http.Hijacker). + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() var v interface{} diff --git a/accept.go b/accept.go index f672a730..f45fdd0b 100644 --- a/accept.go +++ b/accept.go @@ -5,6 +5,7 @@ package websocket import ( "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" @@ -14,7 +15,7 @@ import ( "net/http" "net/textproto" "net/url" - "path/filepath" + "path" "strings" "github.com/coder/websocket/internal/errd" @@ -41,8 +42,8 @@ type AcceptOptions struct { // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host - // with filepath.Match. - // See https://golang.org/pkg/path/filepath/#Match + // with path.Match. + // See https://golang.org/pkg/path/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. @@ -62,6 +63,22 @@ type AcceptOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { @@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. +// +// Note that using the http.Request Context after Accept returns may lead to +// unexpected behavior (see http.Hijacker). func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } @@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { - if errors.Is(err, filepath.ErrBadPattern) { + if errors.Is(err, path.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } @@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) + hj, ok := hijacker(w) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) @@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con client: false, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: brw.Reader, bw: brw.Writer, @@ -221,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) + return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err) } if matched { return nil @@ -234,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { } func match(pattern, s string) (bool, error) { - return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) + return path.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { diff --git a/accept_test.go b/accept_test.go index 4f799126..3b45ac5c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -143,6 +143,33 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + + t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + _, err := Accept(w, r, nil) + assert.Contains(t, err, "failed to hijack connection") + }) + t.Run("closeRace", func(t *testing.T) { t.Parallel() @@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } + +type mockUnwrapper struct { + http.ResponseWriter + unwrap func() http.ResponseWriter +} + +var _ rwUnwrapper = mockUnwrapper{} + +func (mu mockUnwrapper) Unwrap() http.ResponseWriter { + return mu.unwrap() +} diff --git a/autobahn_test.go b/autobahn_test.go index b1b3a7e9..cd0cc9bb 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -92,7 +92,7 @@ func TestAutobahn(t *testing.T) { } }) - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") diff --git a/ci/fmt.sh b/ci/fmt.sh index 31d0c15d..e319a1e4 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -2,22 +2,24 @@ set -eu cd -- "$(dirname "$0")/.." +# Pin golang.org/x/tools, the go.mod of v0.25.0 is incompatible with Go 1.19. +X_TOOLS_VERSION=v0.24.0 + go mod tidy (cd ./internal/thirdparty && go mod tidy) (cd ./internal/examples && go mod tidy) gofmt -w -s . -go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . +go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" . -npx prettier@3.0.3 \ - --write \ +git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \ + --check \ --log-level=warn \ --print-width=90 \ --no-semi \ --single-quote \ - --arrow-parens=avoid \ - $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") + --arrow-parens=avoid -go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go +go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go if [ "${CI-}" ]; then git diff --exit-code diff --git a/ci/lint.sh b/ci/lint.sh index 3cf8eee4..cf9d1abd 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,11 +1,12 @@ #!/bin/sh +set -x set -eu cd -- "$(dirname "$0")/.." go vet ./... GOOS=js GOARCH=wasm go vet ./... -go install honnef.co/go/tools/cmd/staticcheck@latest +go install honnef.co/go/tools/cmd/staticcheck@v0.4.7 staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./... @@ -15,7 +16,7 @@ govulncheck() { cat "$tmpf" fi } -go install golang.org/x/vuln/cmd/govulncheck@latest +go install golang.org/x/vuln/cmd/govulncheck@v1.1.1 govulncheck ./... GOOS=js GOARCH=wasm govulncheck ./... diff --git a/ci/test.sh b/ci/test.sh index a3007614..cc3c22d7 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -24,7 +24,7 @@ cd -- "$(dirname "$0")/.." ) -go install github.com/agnivade/wasmbrowsertest@latest +go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/close.go b/close.go index ff2e878a..f94951dc 100644 --- a/close.go +++ b/close.go @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) { } func (c *Conn) casClosing() bool { - c.closeMu.Lock() - defer c.closeMu.Unlock() - if !c.closing { - c.closing = true - return true - } - return false + return c.closing.Swap(true) } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index 8690fb3b..42fe89fe 100644 --- a/conn.go +++ b/conn.go @@ -69,17 +69,25 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error + + // CloseRead state. closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} + closing atomic.Bool + closeMu sync.Mutex // Protects following. closed chan struct{} - closeMu sync.Mutex - closing bool - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) } type connConfig struct { @@ -88,6 +96,8 @@ type connConfig struct { client bool copts *compressionOptions flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) br *bufio.Reader bw *bufio.Writer @@ -108,8 +118,10 @@ func newConn(cfg connConfig) *Conn { writeTimeout: make(chan context.Context), timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, } c.readMu = newMu(c) @@ -200,9 +212,9 @@ func (c *Conn) flate() bool { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) + p := c.pingCounter.Add(1) - err := c.ping(ctx, strconv.Itoa(int(p))) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } diff --git a/conn_test.go b/conn_test.go index be7c9983..45bb75be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "os" @@ -96,6 +97,85 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingReceivedPongReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Success(t, err) + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2) + assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1)) + }) + + t.Run("pingReceivedPongNotReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1)) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) @@ -364,7 +444,7 @@ func TestWasm(t *testing.T) { defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) + cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { @@ -372,6 +452,18 @@ func TestWasm(t *testing.T) { } } +func cleanEnv(env []string) (out []string) { + for _, e := range env { + // Filter out GITHUB envs and anything with token in it, + // especially GITHUB_TOKEN in CI as it breaks TestWasm. + if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") { + continue + } + out = append(out, e) + } + return out +} + func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) @@ -448,7 +540,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } func BenchmarkConn(b *testing.B) { - var benchCases = []struct { + benchCases := []struct { name string mode websocket.CompressionMode }{ @@ -613,3 +705,149 @@ func TestConcurrentClosePing(t *testing.T) { }() } } + +func TestConnClosePropagation(t *testing.T) { + t.Parallel() + + want := []byte("hello") + keepWriting := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + err := c.Write(context.Background(), websocket.MessageText, want) + if err != nil { + return err + } + } + }) + } + keepReading := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + _, got, err := c.Read(context.Background()) + if err != nil { + return err + } + if !bytes.Equal(want, got) { + return fmt.Errorf("unexpected message: want %q, got %q", want, got) + } + } + }) + } + checkReadErr := func(t *testing.T, err error) { + // Check read error (output depends on when read is called in relation to connection closure). + var ce websocket.CloseError + if errors.As(err, &ce) { + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) + } else { + assert.ErrorIs(t, net.ErrClosed, err) + } + } + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { + for _, c := range conn { + // Check write error. + err := c.Write(context.Background(), websocket.MessageText, want) + assert.ErrorIs(t, net.ErrClosed, err) + + _, _, err = c.Read(context.Background()) + checkReadErr(t, err) + } + } + + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + + _, got, err := other.Read(tt.ctx) + assert.Success(t, err) + assert.Equal(t, "msg", want, got) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + otherReadErr := keepReading(other) + + err := this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = other.CloseRead(tt.ctx) + errs := keepReading(this) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-errs: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + thisReadErr := keepReading(this) + otherReadErr := keepReading(other) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) +} diff --git a/dial.go b/dial.go index ad61a35d..0b11ecbb 100644 --- a/dial.go +++ b/dial.go @@ -48,6 +48,22 @@ type DialOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( client: true, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..9cce45ca --- /dev/null +++ b/hijack.go @@ -0,0 +1,33 @@ +//go:build !js + +package websocket + +import ( + "net/http" +) + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +// hijacker returns the Hijacker interface of the http.ResponseWriter. +// It follows the Unwrap method of the http.ResponseWriter if available, +// matching the behavior of http.ResponseController. If the Hijacker +// interface is not found, it returns false. +// +// Since the http.ResponseController is not available in Go 1.19, and +// does not support checking the presence of the Hijacker interface, +// this function is used to provide a consistent way to check for the +// Hijacker interface across Go versions. +func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t, true + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, false + } + } +} diff --git a/hijack_go120_test.go b/hijack_go120_test.go new file mode 100644 index 00000000..0f0673a9 --- /dev/null +++ b/hijack_go120_test.go @@ -0,0 +1,38 @@ +//go:build !js && go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + _, _, err := http.NewResponseController(w).Hijack() + assert.Contains(t, err, "haha") + hj, ok := hijacker(w) + assert.Equal(t, "hijacker found", ok, true) + _, _, err = hj.Hijack() + assert.Contains(t, err, "haha") +} diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index aa826fba..12cf577a 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -5,15 +5,16 @@ import ( "sync" ) -var bpool sync.Pool +var bpool = sync.Pool{ + New: func() any { + return &bytes.Buffer{} + }, +} // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() - if b == nil { - return &bytes.Buffer{} - } return b.(*bytes.Buffer) } diff --git a/internal/examples/chat/chat.go b/internal/examples/chat/chat.go index 3cb1e021..29f304b7 100644 --- a/internal/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -70,7 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - err := cs.subscribe(r.Context(), w, r) + err := cs.subscribe(w, r) if errors.Is(err, context.Canceled) { return } @@ -111,7 +111,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error { var mu sync.Mutex var c *websocket.Conn var closed bool @@ -142,7 +142,7 @@ func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *h mu.Unlock() defer c.CloseNow() - ctx = c.CloseRead(ctx) + ctx := c.CloseRead(context.Background()) for { select { diff --git a/internal/examples/chat/chat_test.go b/internal/examples/chat/chat_test.go index 8eb72051..dcada0b2 100644 --- a/internal/examples/chat/chat_test.go +++ b/internal/examples/chat/chat_test.go @@ -52,7 +52,7 @@ func Test_chatServer(t *testing.T) { // 10 clients are started that send 128 different // messages of max 128 bytes concurrently. // - // The test verifies that every message is seen by ever client + // The test verifies that every message is seen by every client // and no errors occur anywhere. t.Run("concurrency", func(t *testing.T) { t.Parallel() diff --git a/internal/examples/echo/server.go b/internal/examples/echo/server.go index a44d20b5..37e2f2c4 100644 --- a/internal/examples/echo/server.go +++ b/internal/examples/echo/server.go @@ -37,7 +37,7 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) for { - err = echo(r.Context(), c, l) + err = echo(c, l) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { return } @@ -51,8 +51,8 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. -func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*10) +func echo(c *websocket.Conn, l *rate.Limiter) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := l.Wait(ctx) diff --git a/internal/examples/go.mod b/internal/examples/go.mod index 4f7a8a70..2aa1ee02 100644 --- a/internal/examples/go.mod +++ b/internal/examples/go.mod @@ -6,5 +6,5 @@ replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 - golang.org/x/time v0.3.0 + golang.org/x/time v0.7.0 ) diff --git a/internal/examples/go.sum b/internal/examples/go.sum index f8a07e82..60aa8f9a 100644 --- a/internal/examples/go.sum +++ b/internal/examples/go.sum @@ -1,2 +1,2 @@ -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod index d946ffae..e060ce67 100644 --- a/internal/thirdparty/go.mod +++ b/internal/thirdparty/go.mod @@ -6,38 +6,40 @@ replace github.com/coder/websocket => ../.. require ( github.com/coder/websocket v0.0.0-00010101000000-000000000000 - github.com/gin-gonic/gin v1.9.1 - github.com/gobwas/ws v1.3.0 - github.com/gorilla/websocket v1.5.0 - github.com/lesismal/nbio v1.3.18 + github.com/gin-gonic/gin v1.10.0 + github.com/gobwas/ws v1.4.0 + github.com/gorilla/websocket v1.5.3 + github.com/lesismal/nbio v1.5.12 ) require ( - github.com/bytedance/sonic v1.9.1 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect - github.com/lesismal/llib v1.1.12 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/lesismal/llib v1.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.9.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.17.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum index 1f542103..2352ac75 100644 --- a/internal/thirdparty/go.sum +++ b/internal/thirdparty/go.sum @@ -1,129 +1,107 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= -github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= -github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= -github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= -github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/lesismal/llib v1.1.12 h1:KJFB8bL02V+QGIvILEw/w7s6bKj9Ps9Px97MZP2EOk0= -github.com/lesismal/llib v1.1.12/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= -github.com/lesismal/nbio v1.3.18 h1:kmJZlxjQpVfuCPYcXdv0Biv9LHVViJZet5K99Xs3RAs= -github.com/lesismal/nbio v1.3.18/go.mod h1:KWlouFT5cgDdW5sMX8RsHASUMGniea9X0XIellZ0B38= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lesismal/llib v1.1.13 h1:+w1+t0PykXpj2dXQck0+p6vdC9/mnbEXHgUy/HXDGfE= +github.com/lesismal/llib v1.1.13/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= +github.com/lesismal/nbio v1.5.12 h1:YcUjjmOvmKEANs6Oo175JogXvHy8CuE7i6ccjM2/tv4= +github.com/lesismal/nbio v1.5.12/go.mod h1:QsxE0fKFe1PioyjuHVDn2y8ktYK7xv9MFbpkoRFj8vI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= -github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go deleted file mode 100644 index a0c40204..00000000 --- a/internal/xsync/int64.go +++ /dev/null @@ -1,23 +0,0 @@ -package xsync - -import ( - "sync/atomic" -) - -// Int64 represents an atomic int64. -type Int64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -// Load loads the int64. -func (v *Int64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -// Store stores the int64. -func (v *Int64) Store(i int64) { - v.i.Store(i) -} diff --git a/netconn.go b/netconn.go index 86f7dadb..b118e4d3 100644 --- a/netconn.go +++ b/netconn.go @@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. - atomic.StoreInt64(&nc.writeExpired, 1) + nc.writeExpired.Store(1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C @@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. - atomic.StoreInt64(&nc.readExpired, 1) + nc.readExpired.Store(1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C @@ -94,25 +94,22 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. - // https://github.com/nhooyr/websocket/pull/438 - readExpired int64 - writeExpired int64 - c *Conn msgType MessageType - writeTimer *time.Timer - writeMu *mu - writeCtx context.Context - writeCancel context.CancelFunc - - readTimer *time.Timer - readMu *mu - readCtx context.Context - readCancel context.CancelFunc - readEOFed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeExpired atomic.Int64 + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readExpired atomic.Int64 + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} @@ -129,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() - if atomic.LoadInt64(&nc.writeExpired) == 1 { + if nc.writeExpired.Load() == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } @@ -157,7 +154,7 @@ func (nc *netConn) Read(p []byte) (int, error) { } func (nc *netConn) read(p []byte) (int, error) { - if atomic.LoadInt64(&nc.readExpired) == 1 { + if nc.readExpired.Load() == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } @@ -209,7 +206,7 @@ func (nc *netConn) SetDeadline(t time.Time) error { } func (nc *netConn) SetWriteDeadline(t time.Time) error { - atomic.StoreInt64(&nc.writeExpired, 0) + nc.writeExpired.Store(0) if t.IsZero() { nc.writeTimer.Stop() } else { @@ -223,7 +220,7 @@ func (nc *netConn) SetWriteDeadline(t time.Time) error { } func (nc *netConn) SetReadDeadline(t time.Time) error { - atomic.StoreInt64(&nc.readExpired, 0) + nc.readExpired.Store(0) if t.IsZero() { nc.readTimer.Stop() } else { diff --git a/read.go b/read.go index 1b9404b8..2db22435 100644 --- a/read.go +++ b/read.go @@ -11,11 +11,11 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" - "github.com/coder/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket @@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { +// prepareRead sets the readTimeout context and returns a done function +// to be called after the read is done. It also returns an error if the +// connection is closed. The reference to the error is used to assign +// an error depending on if the connection closed or the context timed +// out during use. Typically the referenced error is a named return +// variable of the function calling this method. +func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: - return header{}, net.ErrClosed + return nil, net.ErrClosed case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) - if err != nil { + done := func() { select { case <-c.closed: - return header{}, net.ErrClosed - case <-ctx.Done(): - return header{}, ctx.Err() - default: - return header{}, err + if *err != nil { + *err = net.ErrClosed + } + case c.readTimeout <- context.Background(): + } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() } } - select { - case <-c.closed: - return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + defer done() + return nil, closeReceivedErr } - return h, nil + return done, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - select { - case <-c.closed: - return 0, net.ErrClosed - case c.readTimeout <- ctx: +func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return header{}, err } + defer readDone() - n, err := io.ReadFull(c.br, p) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { - select { - case <-c.closed: - return n, net.ErrClosed - case <-ctx.Done(): - return n, ctx.Err() - default: - return n, fmt.Errorf("failed to read frame payload: %w", err) - } + return header{}, err } - select { - case <-c.closed: - return n, net.ErrClosed - case c.readTimeout <- context.Background(): + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return 0, err + } + defer readDone() + + n, err := io.ReadFull(c.br, p) + if err != nil { + return n, fmt.Errorf("failed to read frame payload: %w", err) } return n, err @@ -301,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + if c.onPingReceived != nil { + if !c.onPingReceived(ctx, b) { + return nil + } + } return c.writeControl(ctx, opPong, b) case opPong: + if c.onPongReceived != nil { + c.onPongReceived(ctx, b) + } c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() @@ -325,9 +344,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.writeClose(ce.Code, ce.Reason) - c.readMu.unlock() - c.close() + c.closeStateMu.Lock() + c.closeReceivedErr = err + closeSent := c.closeSentErr != nil + c.closeStateMu.Unlock() + + // Only unlock readMu if this connection is being closed becaue + // c.close will try to acquire the readMu lock. We unlock for + // writeClose as well because it may also call c.close. + if !closeSent { + c.readMu.unlock() + _ = c.writeClose(ce.Code, ce.Reason) + } + if !c.casClosing() { + c.readMu.unlock() + _ = c.close() + } return err } @@ -465,7 +497,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit xsync.Int64 + limit atomic.Int64 n int64 } diff --git a/write.go b/write.go index e294a680..7324de74 100644 --- a/write.go +++ b/write.go @@ -5,6 +5,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -14,8 +15,6 @@ import ( "net" "time" - "compress/flate" - "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) @@ -249,22 +248,36 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() + defer func() { + if c.isClosed() && opcode == opClose { + err = nil + } + if err != nil { + if ctx.Err() != nil { + err = ctx.Err() + } else if c.isClosed() { + err = net.ErrClosed + } + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.closeStateMu.Lock() + closeSentErr := c.closeSentErr + c.closeStateMu.Unlock() + if closeSentErr != nil { + return 0, net.ErrClosed + } + select { case <-c.closed: return 0, net.ErrClosed case c.writeTimeout <- ctx: } - defer func() { - if err != nil { - select { - case <-c.closed: - err = net.ErrClosed - case <-ctx.Done(): - err = ctx.Err() - default: - } - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + case c.writeTimeout <- context.Background(): } }() @@ -303,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - select { - case <-c.closed: - if opcode == opClose { - return n, nil + if opcode == opClose { + c.closeStateMu.Lock() + c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) + closeReceived := c.closeReceivedErr != nil + c.closeStateMu.Unlock() + + if closeReceived && !c.casClosing() { + c.writeFrameMu.unlock() + _ = c.close() } - return n, net.ErrClosed - case c.writeTimeout <- context.Background(): } return n, nil diff --git a/ws_js.go b/ws_js.go index a8de0c63..5e324c47 100644 --- a/ws_js.go +++ b/ws_js.go @@ -12,11 +12,11 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall/js" "github.com/coder/websocket/internal/bpool" "github.com/coder/websocket/internal/wsjs" - "github.com/coder/websocket/internal/xsync" ) // opcode represents a WebSocket opcode. @@ -45,7 +45,7 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit xsync.Int64 + msgReadLimit atomic.Int64 closeReadMu sync.Mutex closeReadCtx context.Context pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy