diff --git a/Changelog.md b/Changelog.md index 64f47a397..881b1e414 100644 --- a/Changelog.md +++ b/Changelog.md @@ -8,6 +8,7 @@ - Enforce application protocol negotiation (via `tls.Config.NextProtos`). - Use a varint for error codes. - Add support for [quic-trace](https://github.com/google/quic-trace). +- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`. ## v0.11.0 (2019-04-05) diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index 6685115ea..77e533daf 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -2,6 +2,7 @@ package benchmark import ( "bytes" + "context" "crypto/tls" "fmt" "io" @@ -44,7 +45,7 @@ func init() { ) Expect(err).ToNot(HaveOccurred()) serverAddr <- ln.Addr() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // wait for the client to complete the handshake before sending the data // this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets @@ -66,7 +67,7 @@ func init() { ) Expect(err).ToNot(HaveOccurred()) close(handshakeChan) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) buf := &bytes.Buffer{} diff --git a/example/echo/echo.go b/example/echo/echo.go index 44a93103f..7c1ae2803 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -35,11 +36,11 @@ func echoServer() error { if err != nil { return err } - sess, err := listener.Accept() + sess, err := listener.Accept(context.Background()) if err != nil { return err } - stream, err := sess.AcceptStream() + stream, err := sess.AcceptStream(context.Background()) if err != nil { panic(err) } @@ -58,7 +59,7 @@ func clientMain() error { return err } - stream, err := session.OpenStreamSync() + stream, err := session.OpenStreamSync(context.Background()) if err != nil { return err } diff --git a/http3/client.go b/http3/client.go index f9b66adfd..8744422e5 100644 --- a/http3/client.go +++ b/http3/client.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -100,7 +101,7 @@ func (c *client) dial() error { func (c *client) setupSession() error { // open the control stream - str, err := c.session.OpenUniStreamSync() + str, err := c.session.OpenUniStream() if err != nil { return err } @@ -138,7 +139,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, c.handshakeErr } - str, err := c.session.OpenStreamSync() + str, err := c.session.OpenStreamSync(context.Background()) if err != nil { return nil, err } diff --git a/http3/client_test.go b/http3/client_test.go index c67a3d2b9..f8ccc71c5 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -3,6 +3,7 @@ package http3 import ( "bytes" "compress/gzip" + "context" "crypto/tls" "errors" "io" @@ -126,8 +127,8 @@ var _ = Describe("Client", func() { testErr := errors.New("stream open error") client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) session := mockquic.NewMockSession(mockCtrl) - session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1) - session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1) + session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1) + session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil @@ -169,7 +170,7 @@ var _ = Describe("Client", func() { controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) sess = mockquic.NewMockSession(mockCtrl) - sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1) + sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return sess, nil } @@ -179,7 +180,7 @@ var _ = Describe("Client", func() { }) It("sends a request", func() { - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return buf.Write(p) @@ -200,7 +201,7 @@ var _ = Describe("Client", func() { rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw.WriteHeader(418) - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { @@ -234,7 +235,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { strBuf = &bytes.Buffer{} - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) body := &mockBody{} body.SetData([]byte("request body")) var err error @@ -295,7 +296,7 @@ var _ = Describe("Client", func() { }) It("adds the gzip header to requests", func() { - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return buf.Write(p) @@ -310,7 +311,7 @@ var _ = Describe("Client", func() { It("doesn't add gzip if the header disable it", func() { client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return buf.Write(p) @@ -324,7 +325,7 @@ var _ = Describe("Client", func() { }) It("decompresses the response", func() { - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") @@ -348,7 +349,7 @@ var _ = Describe("Client", func() { }) It("only decompresses the response if the response contains the right content-encoding header", func() { - sess.EXPECT().OpenStreamSync().Return(str, nil) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Write([]byte("not gzipped")) diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index c9e549c04..c6bfa14e8 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -91,8 +92,8 @@ var _ = Describe("RoundTripper", func() { testErr := errors.New("test err") req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr) - session.EXPECT().OpenStreamSync().Return(nil, testErr) + session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) }) _, err = rt.RoundTrip(req) Expect(err).To(MatchError(testErr)) @@ -128,8 +129,8 @@ var _ = Describe("RoundTripper", func() { It("reuses existing clients", func() { closed := make(chan struct{}) testErr := errors.New("test err") - session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr) - session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2) + session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) }) req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/server.go b/http3/server.go index 917b0bae6..f87cfe29f 100644 --- a/http3/server.go +++ b/http3/server.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -114,7 +115,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { s.listenerMutex.Unlock() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return err } @@ -127,7 +128,7 @@ func (s *Server) handleConn(sess quic.Session) { decoder := qpack.NewDecoder(nil) // send a SETTINGS frame - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStream() if err != nil { s.logger.Debugf("Opening the control stream failed.") return @@ -137,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) { str.Write(buf.Bytes()) for { - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) if err != nil { s.logger.Debugf("Accepting stream failed: %s", err) return diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index dffdea202..2724e4c59 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io" "io/ioutil" @@ -8,6 +9,7 @@ import ( "net" "sync" "sync/atomic" + "time" quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/integrationtests/tools/testserver" @@ -33,13 +35,13 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) if _, err = str.Write(testserver.PRData); err != nil { Expect(err).To(MatchError(fmt.Sprintf("Stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) @@ -75,7 +77,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { @@ -119,7 +121,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // only read some data from about 1/3 of the streams if rand.Int31()%3 != 0 { @@ -167,7 +169,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) if err != nil { @@ -196,12 +198,12 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about 2/3 of the streams if rand.Int31()%3 != 0 { @@ -227,12 +229,12 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // only write some data from about 1/3 of the streams, then cancel if rand.Int31()%3 != 0 { @@ -265,13 +267,13 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about half of the streams if rand.Int31()%2 == 0 { @@ -303,7 +305,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around half of the streams if rand.Int31()%2 == 0 { @@ -339,13 +341,13 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about half of the streams length := len(testserver.PRData) @@ -382,7 +384,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) r := io.Reader(str) @@ -418,4 +420,143 @@ var _ = Describe("Stream Cancelations", func() { Expect(server.Close()).To(Succeed()) }) }) + + Context("canceling the context", func() { + It("downloads data when the receiving peer cancels the context for accepting streams", func() { + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + sess, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + ticker := time.NewTicker(5 * time.Millisecond) + for i := 0; i < numStreams; i++ { + <-ticker.C + go func() { + defer GinkgoRecover() + str, err := sess.OpenUniStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(testserver.PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + } + }() + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + &quic.Config{MaxIncomingUniStreams: numStreams / 3}, + ) + Expect(err).ToNot(HaveOccurred()) + + var numToAccept int32 + var counter int32 + var wg sync.WaitGroup + wg.Add(numStreams) + for atomic.LoadInt32(&numToAccept) < numStreams { + ctx, cancel := context.WithCancel(context.Background()) + // cancel accepting half of the streams + if rand.Int31()%2 == 0 { + cancel() + } else { + atomic.AddInt32(&numToAccept, 1) + defer cancel() + } + + go func() { + defer GinkgoRecover() + str, err := sess.AcceptUniStream(ctx) + if err != nil { + atomic.AddInt32(&counter, 1) + Expect(err).To(MatchError("context canceled")) + return + } + data, err := ioutil.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testserver.PRData)) + wg.Done() + }() + } + wg.Wait() + + count := atomic.LoadInt32(&counter) + fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count) + Expect(count).To(BeNumerically(">", numStreams/2)) + Expect(sess.Close()).To(Succeed()) + Expect(server.Close()).To(Succeed()) + }) + + It("downloads data when the sending peer cancels the context for opening streams", func() { + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) + Expect(err).ToNot(HaveOccurred()) + + var numCanceled int32 + go func() { + defer GinkgoRecover() + sess, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + + var numOpened int + ticker := time.NewTicker(250 * time.Microsecond) + for numOpened < numStreams { + <-ticker.C + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // cancel accepting half of the streams + shouldCancel := rand.Int31()%2 == 0 + + if shouldCancel { + time.AfterFunc(5*time.Millisecond, cancel) + } + str, err := sess.OpenUniStreamSync(ctx) + if err != nil { + atomic.AddInt32(&numCanceled, 1) + Expect(err).To(MatchError("context canceled")) + continue + } + numOpened++ + go func(str quic.SendStream) { + defer GinkgoRecover() + _, err = str.Write(testserver.PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }(str) + } + }() + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + &quic.Config{ + MaxIncomingUniStreams: 5, + }, + ) + Expect(err).ToNot(HaveOccurred()) + + var wg sync.WaitGroup + wg.Add(numStreams) + ticker := time.NewTicker(10 * time.Millisecond) + for i := 0; i < numStreams; i++ { + <-ticker.C + go func() { + defer GinkgoRecover() + str, err := sess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testserver.PRData)) + wg.Done() + }() + } + wg.Wait() + + count := atomic.LoadInt32(&numCanceled) + fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count) + Expect(count).To(BeNumerically(">", numStreams/5)) + Expect(sess.Close()).To(Succeed()) + Expect(server.Close()).To(Succeed()) + }) + }) }) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 8128686a9..4b96e1dab 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -25,7 +26,7 @@ var _ = Describe("Connection ID lengths tests", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return } @@ -51,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() { ) Expect(err).ToNot(HaveOccurred()) defer cl.Close() - str, err := cl.AcceptStream() + str, err := cl.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 4b28e8112..fc063793b 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -27,9 +28,9 @@ var _ = Describe("Stream deadline tests", func() { acceptedStream := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - serverStr, err = sess.AcceptStream() + serverStr, err = sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) _, err = serverStr.Read([]byte{0}) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index bf3265ae3..5a20ccd42 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "math/rand" "net" @@ -88,7 +89,7 @@ var _ = Describe("Drop Tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -109,7 +110,7 @@ var _ = Describe("Drop Tests", func() { ) Expect(err).ToNot(HaveOccurred()) defer sess.Close() - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := uint8(1); i <= numMessages; i++ { b := []byte{0} diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 5a6a8dff4..724a72cc8 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" mrand "math/rand" "net" @@ -57,10 +58,10 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) defer sess.Close() - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) @@ -92,7 +93,7 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -106,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 6) _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) @@ -126,7 +127,7 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) serverSessionChan <- sess }() diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 16a993a50..01053295a 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -56,8 +57,7 @@ var _ = Describe("Handshake RTT tests", func() { defer GinkgoRecover() defer close(acceptStopped) for { - _, err := server.Accept() - if err != nil { + if _, err := server.Accept(context.Background()); err != nil { return } } diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 260b2bff2..0192abea6 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -50,7 +51,7 @@ var _ = Describe("Handshake tests", func() { defer GinkgoRecover() defer close(acceptStopped) for { - if _, err := server.Accept(); err != nil { + if _, err := server.Accept(context.Background()); err != nil { return } } @@ -158,7 +159,7 @@ var _ = Describe("Handshake tests", func() { errChan := make(chan error) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) errChan <- err }() Eventually(errChan).Should(Receive(&err)) @@ -236,7 +237,7 @@ var _ = Describe("Handshake tests", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy)) // now accept one session, freeing one spot in the queue - _, err = server.Accept() + _, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // dial again, and expect that this dial succeeds sess, err := dial() @@ -289,7 +290,7 @@ var _ = Describe("Handshake tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) cs := sess.ConnectionState() Expect(cs.NegotiatedProtocol).To(Equal(alpn)) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 3e818074d..a524b29bf 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -25,7 +26,7 @@ var _ = Describe("Multiplexing", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return } @@ -50,7 +51,7 @@ var _ = Describe("Multiplexing", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 89059fce9..fad7fe090 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -55,11 +56,11 @@ var _ = Describe("TLS session resumption", func() { go func() { defer close(done) defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(sess.ConnectionState().DidResume).To(BeFalse()) - sess, err = server.Accept() + sess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(sess.ConnectionState().DidResume).To(BeTrue()) }() diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index 34a54e68e..eff2108a7 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -42,7 +43,7 @@ var _ = Describe("non-zero RTT", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -67,7 +68,7 @@ var _ = Describe("non-zero RTT", func() { &quic.Config{Versions: []protocol.VersionNumber{version}}, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 22745cfe0..aae2ed74b 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "math/rand" "net" @@ -33,7 +34,7 @@ var _ = Describe("Stateless Resets", func() { go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -63,7 +64,7 @@ var _ = Describe("Stateless Resets", func() { }, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data := make([]byte, 6) _, err = str.Read(data) @@ -86,7 +87,7 @@ var _ = Describe("Stateless Resets", func() { acceptStopped := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := ln2.Accept() + _, err := ln2.Accept(context.Background()) Expect(err).To(HaveOccurred()) close(acceptStopped) }() diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 72df4d2f1..609bfc8a7 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -46,7 +47,7 @@ var _ = Describe("Bidirectional streams", func() { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.OpenStreamSync() + str, err := sess.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) data := testserver.GeneratePRData(25 * i) go func() { @@ -70,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptStream() + str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -92,7 +93,7 @@ var _ = Describe("Bidirectional streams", func() { go func() { defer GinkgoRecover() var err error - sess, err = server.Accept() + sess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(sess) }() @@ -109,7 +110,7 @@ var _ = Describe("Bidirectional streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(sess) sess.Close() @@ -129,7 +130,7 @@ var _ = Describe("Bidirectional streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 3c50a277b..8a752e575 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -70,7 +70,7 @@ var _ = Describe("Timeout tests", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() { &quic.Config{IdleTimeout: idleTimeout}, ) Expect(err).ToNot(HaveOccurred()) - strIn, err := sess.AcceptStream() + strIn, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) strOut, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() { checkTimeoutError(err) _, err = sess.OpenUniStream() checkTimeoutError(err) - _, err = sess.AcceptStream() + _, err = sess.AcceptStream(context.Background()) checkTimeoutError(err) - _, err = sess.AcceptUniStream() + _, err = sess.AcceptUniStream(context.Background()) checkTimeoutError(err) }) @@ -146,9 +146,9 @@ var _ = Describe("Timeout tests", func() { serverSessionClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream() // blocks until the session is closed + sess.AcceptStream(context.Background()) // blocks until the session is closed close(serverSessionClosed) }() @@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() @@ -187,9 +187,9 @@ var _ = Describe("Timeout tests", func() { serverSessionClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream() // blocks until the session is closed + sess.AcceptStream(context.Background()) // blocks until the session is closed close(serverSessionClosed) }() @@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream() + _, err := sess.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 1d5a044c2..f2be19a98 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -41,7 +42,7 @@ var _ = Describe("Unidirectional Streams", func() { runSendingPeer := func(sess quic.Session) { for i := 0; i < numStreams; i++ { - str, err := sess.OpenUniStreamSync() + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -56,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptUniStream() + str, err := sess.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -72,7 +73,7 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(sess) sess.Close() @@ -91,7 +92,7 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(sess) }() @@ -109,7 +110,7 @@ var _ = Describe("Unidirectional Streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { diff --git a/interface.go b/interface.go index 5362e3304..8059e25b5 100644 --- a/interface.go +++ b/interface.go @@ -127,11 +127,11 @@ type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. // If the session was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. - AcceptStream() (Stream, error) + AcceptStream(context.Context) (Stream, error) // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. // If the session was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. - AcceptUniStream() (ReceiveStream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) // OpenStream opens a new bidirectional QUIC stream. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream. @@ -143,7 +143,7 @@ type Session interface { // It blocks until a new stream can be opened. // If the error is non-nil, it satisfies the net.Error interface. // If the session was closed due to a timeout, Timeout() will be true. - OpenStreamSync() (Stream, error) + OpenStreamSync(context.Context) (Stream, error) // OpenUniStream opens a new outgoing unidirectional QUIC stream. // If the error is non-nil, it satisfies the net.Error interface. // When reaching the peer's stream limit, Temporary() will be true. @@ -153,7 +153,7 @@ type Session interface { // It blocks until a new stream can be opened. // If the error is non-nil, it satisfies the net.Error interface. // If the session was closed due to a timeout, Timeout() will be true. - OpenUniStreamSync() (SendStream, error) + OpenUniStreamSync(context.Context) (SendStream, error) // LocalAddr returns the local address. LocalAddr() net.Addr // RemoteAddr returns the address of the peer. @@ -232,5 +232,5 @@ type Listener interface { // Addr returns the local network addr that the server is listening on. Addr() net.Addr // Accept returns new sessions. It should be called in a loop. - Accept() (Session, error) + Accept(context.Context) (Session, error) } diff --git a/internal/mocks/quic/session.go b/internal/mocks/quic/session.go index ae6c6b22b..b65f484f4 100644 --- a/internal/mocks/quic/session.go +++ b/internal/mocks/quic/session.go @@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder { } // AcceptStream mocks base method -func (m *MockSession) AcceptStream() (quic_go.Stream, error) { +func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(quic_go.Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) { +func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(quic_go.ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0) } // Close mocks base method @@ -154,18 +154,18 @@ func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call { } // OpenStreamSync mocks base method -func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) { +func (m *MockSession) OpenStreamSync(arg0 context.Context) (quic_go.Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync") + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) ret0, _ := ret[0].(quic_go.Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenStreamSync indicates an expected call of OpenStreamSync -func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call { +func (mr *MockSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync), arg0) } // OpenUniStream mocks base method @@ -184,18 +184,18 @@ func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call { } // OpenUniStreamSync mocks base method -func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) { +func (m *MockSession) OpenUniStreamSync(arg0 context.Context) (quic_go.SendStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync") + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) ret0, _ := ret[0].(quic_go.SendStream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync -func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call { +func (mr *MockSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync), arg0) } // RemoteAddr mocks base method diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 833331885..30bcd90b5 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder { } // AcceptStream mocks base method -func (m *MockQuicSession) AcceptStream() (Stream, error) { +func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) { +func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0) } // Close mocks base method @@ -167,18 +167,18 @@ func (mr *MockQuicSessionMockRecorder) OpenStream() *gomock.Call { } // OpenStreamSync mocks base method -func (m *MockQuicSession) OpenStreamSync() (Stream, error) { +func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync") + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenStreamSync indicates an expected call of OpenStreamSync -func (mr *MockQuicSessionMockRecorder) OpenStreamSync() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync), arg0) } // OpenUniStream mocks base method @@ -197,18 +197,18 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStream() *gomock.Call { } // OpenUniStreamSync mocks base method -func (m *MockQuicSession) OpenUniStreamSync() (SendStream, error) { +func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync") + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) ret0, _ := ret[0].(SendStream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync -func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync() *gomock.Call { +func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0) } // RemoteAddr mocks base method diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 1f9656609..2732fda41 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -5,6 +5,7 @@ package quic import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder { } // AcceptStream mocks base method -func (m *MockStreamManager) AcceptStream() (Stream, error) { +func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptStream") + ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptStream indicates an expected call of AcceptStream -func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method -func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) { +func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptUniStream") + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(ReceiveStream) ret1, _ := ret[1].(error) return ret0, ret1 } // AcceptUniStream indicates an expected call of AcceptUniStream -func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) } // CloseWithError mocks base method @@ -152,18 +153,18 @@ func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call { } // OpenStreamSync mocks base method -func (m *MockStreamManager) OpenStreamSync() (Stream, error) { +func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenStreamSync") + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) ret0, _ := ret[0].(Stream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenStreamSync indicates an expected call of OpenStreamSync -func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) } // OpenUniStream mocks base method @@ -182,18 +183,18 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call { } // OpenUniStreamSync mocks base method -func (m *MockStreamManager) OpenUniStreamSync() (SendStream, error) { +func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenUniStreamSync") + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) ret0, _ := ret[0].(SendStream) ret1, _ := ret[1].(error) return ret0, ret1 } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) } // UpdateLimits mocks base method diff --git a/server.go b/server.go index c080d063f..f6510bdc1 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -284,9 +285,11 @@ func populateServerConfig(config *Config) *Config { } // Accept returns newly openend sessions -func (s *server) Accept() (Session, error) { +func (s *server) Accept(ctx context.Context) (Session, error) { var sess Session select { + case <-ctx.Done(): + return nil, ctx.Err() case sess = <-s.sessionQueue: return sess, nil case <-s.errorChan: diff --git a/server_test.go b/server_test.go index 008d7dcc0..0a966688d 100644 --- a/server_test.go +++ b/server_test.go @@ -434,7 +434,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.Accept() + serv.Accept(context.Background()) close(done) }() Consistently(done).ShouldNot(BeClosed()) @@ -461,7 +461,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := serv.Accept() + _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) close(done) }() @@ -474,18 +474,33 @@ var _ = Describe("Server", func() { testErr := errors.New("test err") serv.setCloseError(testErr) for i := 0; i < 3; i++ { - _, err := serv.Accept() + _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) } }) + It("returns when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + It("accepts new sessions when the handshake completes", func() { sess := NewMockQuicSession(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() - s, err := serv.Accept() + s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(s).To(Equal(sess)) close(done) diff --git a/session.go b/session.go index 790f82660..8ae7ace9d 100644 --- a/session.go +++ b/session.go @@ -37,10 +37,10 @@ type streamManager interface { GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) OpenStream() (Stream, error) OpenUniStream() (SendStream, error) - OpenStreamSync() (Stream, error) - OpenUniStreamSync() (SendStream, error) - AcceptStream() (Stream, error) - AcceptUniStream() (ReceiveStream, error) + OpenStreamSync(context.Context) (Stream, error) + OpenUniStreamSync(context.Context) (SendStream, error) + AcceptStream(context.Context) (Stream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) DeleteStream(protocol.StreamID) error UpdateLimits(*handshake.TransportParameters) error HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error @@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) { } // AcceptStream returns the next stream openend by the peer -func (s *session) AcceptStream() (Stream, error) { - return s.streamsMap.AcceptStream() +func (s *session) AcceptStream(ctx context.Context) (Stream, error) { + return s.streamsMap.AcceptStream(ctx) } -func (s *session) AcceptUniStream() (ReceiveStream, error) { - return s.streamsMap.AcceptUniStream() +func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + return s.streamsMap.AcceptUniStream(ctx) } // OpenStream opens a stream @@ -1246,16 +1246,16 @@ func (s *session) OpenStream() (Stream, error) { return s.streamsMap.OpenStream() } -func (s *session) OpenStreamSync() (Stream, error) { - return s.streamsMap.OpenStreamSync() +func (s *session) OpenStreamSync(ctx context.Context) (Stream, error) { + return s.streamsMap.OpenStreamSync(ctx) } func (s *session) OpenUniStream() (SendStream, error) { return s.streamsMap.OpenUniStream() } -func (s *session) OpenUniStreamSync() (SendStream, error) { - return s.streamsMap.OpenUniStreamSync() +func (s *session) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + return s.streamsMap.OpenUniStreamSync(ctx) } func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { diff --git a/session_test.go b/session_test.go index b721a21c8..9aca5f0cc 100644 --- a/session_test.go +++ b/session_test.go @@ -367,14 +367,6 @@ var _ = Describe("Session", func() { Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) }) - It("accepts new streams", func() { - mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().AcceptStream().Return(mstr, nil) - str, err := sess.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(mstr)) - }) - Context("closing", func() { var ( runErr error @@ -1431,8 +1423,8 @@ var _ = Describe("Session", func() { It("opens streams synchronously", func() { mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().OpenStreamSync().Return(mstr, nil) - str, err := sess.OpenStreamSync() + streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil) + str, err := sess.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -1447,24 +1439,28 @@ var _ = Describe("Session", func() { It("opens unidirectional streams synchronously", func() { mstr := NewMockSendStreamI(mockCtrl) - streamManager.EXPECT().OpenUniStreamSync().Return(mstr, nil) - str, err := sess.OpenUniStreamSync() + streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil) + str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) It("accepts streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() mstr := NewMockStreamI(mockCtrl) - streamManager.EXPECT().AcceptStream().Return(mstr, nil) - str, err := sess.AcceptStream() + streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil) + str, err := sess.AcceptStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) It("accepts unidirectional streams", func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() mstr := NewMockReceiveStreamI(mockCtrl) - streamManager.EXPECT().AcceptUniStream().Return(mstr, nil) - str, err := sess.AcceptUniStream() + streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil) + str, err := sess.AcceptUniStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) diff --git a/streams_map.go b/streams_map.go index cc53efd16..4dc8fc313 100644 --- a/streams_map.go +++ b/streams_map.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "fmt" "net" @@ -108,8 +109,8 @@ func (m *streamsMap) OpenStream() (Stream, error) { return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } -func (m *streamsMap) OpenStreamSync() (Stream, error) { - str, err := m.outgoingBidiStreams.OpenStreamSync() +func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { + str, err := m.outgoingBidiStreams.OpenStreamSync(ctx) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } @@ -118,18 +119,18 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) { return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } -func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { - str, err := m.outgoingUniStreams.OpenStreamSync() +func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + str, err := m.outgoingUniStreams.OpenStreamSync(ctx) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) } -func (m *streamsMap) AcceptStream() (Stream, error) { - str, err := m.incomingBidiStreams.AcceptStream() +func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { + str, err := m.incomingBidiStreams.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) } -func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { - str, err := m.incomingUniStreams.AcceptStream() +func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + str, err := m.incomingUniStreams.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) } diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index f0aad6a27..f24b9ec22 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -12,8 +13,8 @@ import ( ) type incomingBidiStreamsMap struct { - mutex sync.RWMutex - cond sync.Cond + mutex sync.RWMutex + newStreamChan chan struct{} streams map[protocol.StreamNum]streamI // When a stream is deleted before it was accepted, we can't delete it immediately. @@ -36,9 +37,9 @@ func newIncomingBidiStreamsMap( newStream func(protocol.StreamNum) streamI, maxStreams uint64, queueControlFrame func(wire.Frame), - // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingBidiStreamsMap { - m := &incomingBidiStreamsMap{ + return &incomingBidiStreamsMap{ + newStreamChan: make(chan struct{}), streams: make(map[protocol.StreamNum]streamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -47,38 +48,43 @@ func newIncomingBidiStreamsMap( nextStreamToOpen: 1, nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - // streamNumToID: streamNumToID, } - m.cond.L = &m.mutex - return m } -func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { +func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str streamI for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } - m.cond.Wait() + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() } m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } @@ -111,7 +117,10 @@ func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (stream // * highestStream is only modified by this function for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { m.streams[newNum] = m.newStream(newNum) - m.cond.Signal() + select { + case m.newStreamChan <- struct{}{}: + default: + } } m.nextStreamToOpen = num + 1 s := m.streams[num] @@ -167,5 +176,5 @@ func (m *incomingBidiStreamsMap) CloseWithError(err error) { str.closeForShutdown(err) } m.mutex.Unlock() - m.cond.Broadcast() + close(m.newStreamChan) } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index c7485aca3..f8ace8bfb 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -1,6 +1,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,8 +11,8 @@ import ( //go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" //go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" type incomingItemsMap struct { - mutex sync.RWMutex - cond sync.Cond + mutex sync.RWMutex + newStreamChan chan struct{} streams map[protocol.StreamNum]item // When a stream is deleted before it was accepted, we can't delete it immediately. @@ -34,9 +35,9 @@ func newIncomingItemsMap( newStream func(protocol.StreamNum) item, maxStreams uint64, queueControlFrame func(wire.Frame), - // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingItemsMap { - m := &incomingItemsMap{ + return &incomingItemsMap{ + newStreamChan: make(chan struct{}), streams: make(map[protocol.StreamNum]item), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -45,38 +46,43 @@ func newIncomingItemsMap( nextStreamToOpen: 1, nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - // streamNumToID: streamNumToID, } - m.cond.L = &m.mutex - return m } -func (m *incomingItemsMap) AcceptStream() (item, error) { +func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str item for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } - m.cond.Wait() + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() } m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } @@ -109,7 +115,10 @@ func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) // * highestStream is only modified by this function for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { m.streams[newNum] = m.newStream(newNum) - m.cond.Signal() + select { + case m.newStreamChan <- struct{}{}: + default: + } } m.nextStreamToOpen = num + 1 s := m.streams[num] @@ -165,5 +174,5 @@ func (m *incomingItemsMap) CloseWithError(err error) { str.closeForShutdown(err) } m.mutex.Unlock() - m.cond.Broadcast() + close(m.newStreamChan) } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 0a59b7f57..f62e6d5c6 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "github.com/golang/mock/gomock" @@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() { It("accepts streams in the right order", func() { _, err := m.GetOrOpenStream(2) // open streams 1 and 2 Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) @@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() { strChan := make(chan item) go func() { defer GinkgoRecover() - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) strChan <- str }() @@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) }) + It("unblocks AcceptStream when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + It("unblocks AcceptStream when it is closed", func() { testErr := errors.New("test error") done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).To(MatchError(testErr)) close(done) }() @@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() { It("errors AcceptStream immediately if it is closed", func() { testErr := errors.New("test error") m.CloseWithError(testErr) - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).To(MatchError(testErr)) }) @@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) Expect(m.DeleteStream(1)).To(Succeed()) @@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() { _, err := m.GetOrOpenStream(2) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(2)).To(Succeed()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) }) @@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(str).To(BeNil()) // when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued mockSender.EXPECT().queueControlFrame(gomock.Any()) - str, err = m.AcceptStream() + str, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) }) @@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(err).ToNot(HaveOccurred()) // accept all streams for i := 0; i < 5; i++ { - _, err := m.AcceptStream() + _, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) } mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index cc83eb667..c146f3ab1 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -12,8 +13,8 @@ import ( ) type incomingUniStreamsMap struct { - mutex sync.RWMutex - cond sync.Cond + mutex sync.RWMutex + newStreamChan chan struct{} streams map[protocol.StreamNum]receiveStreamI // When a stream is deleted before it was accepted, we can't delete it immediately. @@ -36,9 +37,9 @@ func newIncomingUniStreamsMap( newStream func(protocol.StreamNum) receiveStreamI, maxStreams uint64, queueControlFrame func(wire.Frame), - // streamNumToID func(protocol.StreamNum) protocol.StreamID, ) *incomingUniStreamsMap { - m := &incomingUniStreamsMap{ + return &incomingUniStreamsMap{ + newStreamChan: make(chan struct{}), streams: make(map[protocol.StreamNum]receiveStreamI), streamsToDelete: make(map[protocol.StreamNum]struct{}), maxStream: protocol.StreamNum(maxStreams), @@ -47,38 +48,43 @@ func newIncomingUniStreamsMap( nextStreamToOpen: 1, nextStreamToAccept: 1, queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - // streamNumToID: streamNumToID, } - m.cond.L = &m.mutex - return m } -func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { +func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { m.mutex.Lock() - defer m.mutex.Unlock() var num protocol.StreamNum var str receiveStreamI for { num = m.nextStreamToAccept - var ok bool if m.closeErr != nil { + m.mutex.Unlock() return nil, m.closeErr } + var ok bool str, ok = m.streams[num] if ok { break } - m.cond.Wait() + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() } m.nextStreamToAccept++ // If this stream was completed before being accepted, we can delete it now. if _, ok := m.streamsToDelete[num]; ok { delete(m.streamsToDelete, num) if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() return nil, err } } + m.mutex.Unlock() return str, nil } @@ -111,7 +117,10 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive // * highestStream is only modified by this function for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { m.streams[newNum] = m.newStream(newNum) - m.cond.Signal() + select { + case m.newStreamChan <- struct{}{}: + default: + } } m.nextStreamToOpen = num + 1 s := m.streams[num] @@ -167,5 +176,5 @@ func (m *incomingUniStreamsMap) CloseWithError(err error) { str.closeForShutdown(err) } m.mutex.Unlock() - m.cond.Broadcast() + close(m.newStreamChan) } diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index 6c5176f24..5c0ff71ff 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -14,10 +15,12 @@ import ( type outgoingBidiStreamsMap struct { mutex sync.RWMutex - openQueue []chan struct{} - streams map[protocol.StreamNum]streamI + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) maxStream protocol.StreamNum // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream @@ -34,6 +37,7 @@ func newOutgoingBidiStreamsMap( ) *outgoingBidiStreamsMap { return &outgoingBidiStreamsMap{ streams: make(map[protocol.StreamNum]streamI), + openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, newStream: newStream, @@ -57,7 +61,7 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { return m.openStream(), nil } -func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { +func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -65,17 +69,32 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { return nil, m.closeErr } + if err := ctx.Err(); err != nil { + return nil, err + } + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { return m.openStream(), nil } waitChan := make(chan struct{}, 1) - m.openQueue = append(m.openQueue, waitChan) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan m.maybeSendBlockedFrame() for { m.mutex.Unlock() - <-waitChan + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } m.mutex.Lock() if m.closeErr != nil { @@ -86,7 +105,7 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { continue } str := m.openStream() - m.openQueue = m.openQueue[1:] + delete(m.openQueue, queuePos) m.unblockOpenSync() return str, nil } @@ -159,9 +178,15 @@ func (m *outgoingBidiStreamsMap) unblockOpenSync() { if len(m.openQueue) == 0 { return } - select { - case m.openQueue[0] <- struct{}{}: - default: + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + close(c) + m.openQueue[qp] = nil + m.lowestInQueue = qp + 1 + return } } @@ -172,7 +197,9 @@ func (m *outgoingBidiStreamsMap) CloseWithError(err error) { str.closeForShutdown(err) } for _, c := range m.openQueue { - close(c) + if c != nil { + close(c) + } } m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 1004ff933..6b6206404 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -1,6 +1,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -12,10 +13,12 @@ import ( type outgoingItemsMap struct { mutex sync.RWMutex - openQueue []chan struct{} - streams map[protocol.StreamNum]item + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) maxStream protocol.StreamNum // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream @@ -32,6 +35,7 @@ func newOutgoingItemsMap( ) *outgoingItemsMap { return &outgoingItemsMap{ streams: make(map[protocol.StreamNum]item), + openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, newStream: newStream, @@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) { return m.openStream(), nil } -func (m *outgoingItemsMap) OpenStreamSync() (item, error) { +func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) { return nil, m.closeErr } + if err := ctx.Err(); err != nil { + return nil, err + } + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { return m.openStream(), nil } waitChan := make(chan struct{}, 1) - m.openQueue = append(m.openQueue, waitChan) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan m.maybeSendBlockedFrame() for { m.mutex.Unlock() - <-waitChan + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } m.mutex.Lock() if m.closeErr != nil { @@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) { continue } str := m.openStream() - m.openQueue = m.openQueue[1:] + delete(m.openQueue, queuePos) m.unblockOpenSync() return str, nil } @@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() { if len(m.openQueue) == 0 { return } - select { - case m.openQueue[0] <- struct{}{}: - default: + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + close(c) + m.openQueue[qp] = nil + m.lowestInQueue = qp + 1 + return } } @@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) { str.closeForShutdown(err) } for _, c := range m.openQueue { - close(c) + if c != nil { + close(c) + } } m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index 060145ea0..981b5b3c2 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "github.com/golang/mock/gomock" @@ -106,12 +107,19 @@ var _ = Describe("Streams Map (outgoing)", func() { expectTooManyStreamsError(err) }) + It("returns immediately when called with a canceled context", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := m.OpenStreamSync(ctx) + Expect(err).To(MatchError("context canceled")) + }) + It("blocks until a stream can be opened synchronously", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() - str, err := m.OpenStreamSync() + str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) close(done) @@ -122,12 +130,34 @@ var _ = Describe("Streams Map (outgoing)", func() { Eventually(done).Should(BeClosed()) }) + It("unblocks when the context is canceled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + + // make sure that the next stream openend is stream 1 + m.SetMaxStream(1000) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + }) + It("opens streams in the right order", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() done1 := make(chan struct{}) go func() { defer GinkgoRecover() - str, err := m.OpenStreamSync() + str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) close(done1) @@ -136,7 +166,7 @@ var _ = Describe("Streams Map (outgoing)", func() { done2 := make(chan struct{}) go func() { defer GinkgoRecover() - str, err := m.OpenStreamSync() + str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) close(done2) @@ -155,20 +185,20 @@ var _ = Describe("Streams Map (outgoing)", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := m.OpenStreamSync() + _, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) done <- struct{}{} }() go func() { defer GinkgoRecover() - _, err := m.OpenStreamSync() + _, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) done <- struct{}{} }() Consistently(done).ShouldNot(Receive()) go func() { defer GinkgoRecover() - _, err := m.OpenStreamSync() + _, err := m.OpenStreamSync(context.Background()) Expect(err).To(MatchError("test done")) done <- struct{}{} }() @@ -188,7 +218,7 @@ var _ = Describe("Streams Map (outgoing)", func() { openedSync := make(chan struct{}) go func() { defer GinkgoRecover() - str, err := m.OpenStreamSync() + str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) close(openedSync) @@ -229,7 +259,7 @@ var _ = Describe("Streams Map (outgoing)", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := m.OpenStreamSync() + _, err := m.OpenStreamSync(context.Background()) Expect(err).To(MatchError(testErr)) close(done) }() diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 841eadcea..986cb8a9b 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -5,6 +5,7 @@ package quic import ( + "context" "sync" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -14,10 +15,12 @@ import ( type outgoingUniStreamsMap struct { mutex sync.RWMutex - openQueue []chan struct{} - streams map[protocol.StreamNum]sendStreamI + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) maxStream protocol.StreamNum // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream @@ -34,6 +37,7 @@ func newOutgoingUniStreamsMap( ) *outgoingUniStreamsMap { return &outgoingUniStreamsMap{ streams: make(map[protocol.StreamNum]sendStreamI), + openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, newStream: newStream, @@ -57,7 +61,7 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { return m.openStream(), nil } -func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { +func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -65,17 +69,32 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { return nil, m.closeErr } + if err := ctx.Err(); err != nil { + return nil, err + } + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { return m.openStream(), nil } waitChan := make(chan struct{}, 1) - m.openQueue = append(m.openQueue, waitChan) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan m.maybeSendBlockedFrame() for { m.mutex.Unlock() - <-waitChan + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } m.mutex.Lock() if m.closeErr != nil { @@ -86,7 +105,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { continue } str := m.openStream() - m.openQueue = m.openQueue[1:] + delete(m.openQueue, queuePos) m.unblockOpenSync() return str, nil } @@ -159,9 +178,15 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() { if len(m.openQueue) == 0 { return } - select { - case m.openQueue[0] <- struct{}{}: - default: + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + close(c) + m.openQueue[qp] = nil + m.lowestInQueue = qp + 1 + return } } @@ -172,7 +197,9 @@ func (m *outgoingUniStreamsMap) CloseWithError(err error) { str.closeForShutdown(err) } for _, c := range m.openQueue { - close(c) + if c != nil { + close(c) + } } m.mutex.Unlock() } diff --git a/streams_map_test.go b/streams_map_test.go index cbace4669..7ab5bf7de 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "fmt" "net" @@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() { It("accepts bidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&stream{})) Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream)) @@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() { It("accepts unidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) Expect(err).ToNot(HaveOccurred()) - str, err := m.AcceptUniStream() + str, err := m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&receiveStream{})) Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream)) @@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenReceiveStream(id) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptStream() + str, err := m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(str.StreamID()).To(Equal(id)) @@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenReceiveStream(id) Expect(err).ToNot(HaveOccurred()) Expect(m.DeleteStream(id)).To(Succeed()) - str, err := m.AcceptUniStream() + str, err := m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(str.StreamID()).To(Equal(id)) @@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() { It("sends a MAX_STREAMS frame for bidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptStream() + _, err = m.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, @@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() { It("sends a MAX_STREAMS frame for unidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) Expect(err).ToNot(HaveOccurred()) - _, err = m.AcceptUniStream() + _, err = m.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, @@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() { _, err = m.OpenUniStream() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptStream() + _, err = m.AcceptStream(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) - _, err = m.AcceptUniStream() + _, err = m.AcceptUniStream(context.Background()) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) })