diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index fb966638f..a77065e32 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -25,7 +25,8 @@ var _ = Describe("Stream Cancelations", func() { // The server accepts a single session, and then opens numStreams unidirectional streams. // On each of these streams, it (tries to) write PRData. - runServer := func() <-chan int32 { + // When done, it sends the number of canceled streams on the channel. + runServer := func(data []byte) <-chan int32 { numCanceledStreamsChan := make(chan int32) var err error server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) @@ -44,7 +45,7 @@ var _ = Describe("Stream Cancelations", func() { defer wg.Done() str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - if _, err := str.Write(PRData); err != nil { + if _, err := str.Write(data); err != nil { Expect(err).To(MatchError(&quic.StreamError{ StreamID: str.StreamID(), ErrorCode: quic.StreamErrorCode(str.StreamID()), @@ -70,7 +71,7 @@ var _ = Describe("Stream Cancelations", func() { }) It("downloads when the client immediately cancels most streams", func() { - serverCanceledCounterChan := runServer() + serverCanceledCounterChan := runServer(PRData) sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -113,7 +114,7 @@ var _ = Describe("Stream Cancelations", func() { }) It("downloads when the client cancels streams after reading from them for a bit", func() { - serverCanceledCounterChan := runServer() + serverCanceledCounterChan := runServer(PRData) sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -159,6 +160,51 @@ var _ = Describe("Stream Cancelations", func() { Expect(clientCanceledCounter).To(BeNumerically(">", numStreams/10)) Expect(numStreams - clientCanceledCounter).To(BeNumerically(">", numStreams/10)) }) + + It("allows concurrent Read and CancelRead calls", func() { + // This test is especially valuable when run with race detector, + // see https://github.com/lucas-clemente/quic-go/issues/3239. + serverCanceledCounterChan := runServer(make([]byte, 100)) // make sure the FIN is sent with the STREAM frame + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), + ) + Expect(err).ToNot(HaveOccurred()) + + var wg sync.WaitGroup + wg.Add(numStreams) + var counter int32 + for i := 0; i < numStreams; i++ { + go func() { + defer GinkgoRecover() + defer wg.Done() + str, err := sess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + b := make([]byte, 32) + if _, err := str.Read(b); err != nil { + atomic.AddInt32(&counter, 1) + Expect(err.Error()).To(ContainSubstring("canceled with error code 1234")) + return + } + }() + go str.CancelRead(1234) + Eventually(done).Should(BeClosed()) + }() + } + wg.Wait() + Expect(sess.CloseWithError(0, "")).To(Succeed()) + numCanceled := atomic.LoadInt32(&counter) + fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams) + Expect(numCanceled).ToNot(BeZero()) + Eventually(serverCanceledCounterChan).Should(Receive()) + }) }) Context("canceling the write side", func() { diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 4cb94d902..aa66aef1a 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -109,7 +109,10 @@ func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { } func (c *streamFlowController) Abandon() { - if unread := c.highestReceived - c.bytesRead; unread > 0 { + c.mutex.Lock() + unread := c.highestReceived - c.bytesRead + c.mutex.Unlock() + if unread > 0 { c.connection.AddBytesRead(unread) } }