Merge pull request #1952 from lucas-clemente/contexts

add contexts to all blocking functions
This commit is contained in:
Marten Seemann
2019-06-22 19:52:47 +08:00
committed by GitHub
39 changed files with 603 additions and 274 deletions

View File

@@ -8,6 +8,7 @@
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
## v0.11.0 (2019-04-05)

View File

@@ -2,6 +2,7 @@ package benchmark
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
@@ -44,7 +45,7 @@ func init() {
)
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
// wait for the client to complete the handshake before sending the data
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
@@ -66,7 +67,7 @@ func init() {
)
Expect(err).ToNot(HaveOccurred())
close(handshakeChan)
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -35,11 +36,11 @@ func echoServer() error {
if err != nil {
return err
}
sess, err := listener.Accept()
sess, err := listener.Accept(context.Background())
if err != nil {
return err
}
stream, err := sess.AcceptStream()
stream, err := sess.AcceptStream(context.Background())
if err != nil {
panic(err)
}
@@ -58,7 +59,7 @@ func clientMain() error {
return err
}
stream, err := session.OpenStreamSync()
stream, err := session.OpenStreamSync(context.Background())
if err != nil {
return err
}

View File

@@ -2,6 +2,7 @@ package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@@ -100,7 +101,7 @@ func (c *client) dial() error {
func (c *client) setupSession() error {
// open the control stream
str, err := c.session.OpenUniStreamSync()
str, err := c.session.OpenUniStream()
if err != nil {
return err
}
@@ -138,7 +139,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, c.handshakeErr
}
str, err := c.session.OpenStreamSync()
str, err := c.session.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package http3
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
@@ -126,8 +127,8 @@ var _ = Describe("Client", func() {
testErr := errors.New("stream open error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockSession(mockCtrl)
session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
@@ -169,7 +170,7 @@ var _ = Describe("Client", func() {
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockSession(mockCtrl)
sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1)
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return sess, nil
}
@@ -179,7 +180,7 @@ var _ = Describe("Client", func() {
})
It("sends a request", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@@ -200,7 +201,7 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
@@ -234,7 +235,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
@@ -295,7 +296,7 @@ var _ = Describe("Client", func() {
})
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@@ -310,7 +311,7 @@ var _ = Describe("Client", func() {
It("doesn't add gzip if the header disable it", func() {
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
@@ -324,7 +325,7 @@ var _ = Describe("Client", func() {
})
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
@@ -348,7 +349,7 @@ var _ = Describe("Client", func() {
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))

View File

@@ -2,6 +2,7 @@ package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@@ -91,8 +92,8 @@ var _ = Describe("RoundTripper", func() {
testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr))
@@ -128,8 +129,8 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())

View File

@@ -2,6 +2,7 @@ package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@@ -114,7 +115,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
s.listenerMutex.Unlock()
for {
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
if err != nil {
return err
}
@@ -127,7 +128,7 @@ func (s *Server) handleConn(sess quic.Session) {
decoder := qpack.NewDecoder(nil)
// send a SETTINGS frame
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStream()
if err != nil {
s.logger.Debugf("Opening the control stream failed.")
return
@@ -137,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) {
str.Write(buf.Bytes())
for {
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
if err != nil {
s.logger.Debugf("Accepting stream failed: %s", err)
return

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io"
"io/ioutil"
@@ -8,6 +9,7 @@ import (
"net"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
@@ -33,13 +35,13 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover()
var wg sync.WaitGroup
wg.Add(numStreams)
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
if _, err = str.Write(testserver.PRData); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("Stream %d was reset with error code %d", str.StreamID(), str.StreamID())))
@@ -75,7 +77,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel around 2/3 of the streams
if rand.Int31()%3 != 0 {
@@ -119,7 +121,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
// only read some data from about 1/3 of the streams
if rand.Int31()%3 != 0 {
@@ -167,7 +169,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
if err != nil {
@@ -196,12 +198,12 @@ var _ = Describe("Stream Cancelations", func() {
var canceledCounter int32
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about 2/3 of the streams
if rand.Int31()%3 != 0 {
@@ -227,12 +229,12 @@ var _ = Describe("Stream Cancelations", func() {
var canceledCounter int32
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// only write some data from about 1/3 of the streams, then cancel
if rand.Int31()%3 != 0 {
@@ -265,13 +267,13 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover()
var wg sync.WaitGroup
wg.Add(numStreams)
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about half of the streams
if rand.Int31()%2 == 0 {
@@ -303,7 +305,7 @@ var _ = Describe("Stream Cancelations", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel around half of the streams
if rand.Int31()%2 == 0 {
@@ -339,13 +341,13 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover()
var wg sync.WaitGroup
wg.Add(numStreams)
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
// cancel about half of the streams
length := len(testserver.PRData)
@@ -382,7 +384,7 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
r := io.Reader(str)
@@ -418,4 +420,143 @@ var _ = Describe("Stream Cancelations", func() {
Expect(server.Close()).To(Succeed())
})
})
Context("canceling the context", func() {
It("downloads data when the receiving peer cancels the context for accepting streams", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
ticker := time.NewTicker(5 * time.Millisecond)
for i := 0; i < numStreams; i++ {
<-ticker.C
go func() {
defer GinkgoRecover()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
}()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
&quic.Config{MaxIncomingUniStreams: numStreams / 3},
)
Expect(err).ToNot(HaveOccurred())
var numToAccept int32
var counter int32
var wg sync.WaitGroup
wg.Add(numStreams)
for atomic.LoadInt32(&numToAccept) < numStreams {
ctx, cancel := context.WithCancel(context.Background())
// cancel accepting half of the streams
if rand.Int31()%2 == 0 {
cancel()
} else {
atomic.AddInt32(&numToAccept, 1)
defer cancel()
}
go func() {
defer GinkgoRecover()
str, err := sess.AcceptUniStream(ctx)
if err != nil {
atomic.AddInt32(&counter, 1)
Expect(err).To(MatchError("context canceled"))
return
}
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
wg.Done()
}()
}
wg.Wait()
count := atomic.LoadInt32(&counter)
fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count)
Expect(count).To(BeNumerically(">", numStreams/2))
Expect(sess.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
It("downloads data when the sending peer cancels the context for opening streams", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
var numCanceled int32
go func() {
defer GinkgoRecover()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
var numOpened int
ticker := time.NewTicker(250 * time.Microsecond)
for numOpened < numStreams {
<-ticker.C
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// cancel accepting half of the streams
shouldCancel := rand.Int31()%2 == 0
if shouldCancel {
time.AfterFunc(5*time.Millisecond, cancel)
}
str, err := sess.OpenUniStreamSync(ctx)
if err != nil {
atomic.AddInt32(&numCanceled, 1)
Expect(err).To(MatchError("context canceled"))
continue
}
numOpened++
go func(str quic.SendStream) {
defer GinkgoRecover()
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}(str)
}
}()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
&quic.Config{
MaxIncomingUniStreams: 5,
},
)
Expect(err).ToNot(HaveOccurred())
var wg sync.WaitGroup
wg.Add(numStreams)
ticker := time.NewTicker(10 * time.Millisecond)
for i := 0; i < numStreams; i++ {
<-ticker.C
go func() {
defer GinkgoRecover()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
wg.Done()
}()
}
wg.Wait()
count := atomic.LoadInt32(&numCanceled)
fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count)
Expect(count).To(BeNumerically(">", numStreams/5))
Expect(sess.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
})
})

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"math/rand"
@@ -25,7 +26,7 @@ var _ = Describe("Connection ID lengths tests", func() {
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
if err != nil {
return
}
@@ -51,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
)
Expect(err).ToNot(HaveOccurred())
defer cl.Close()
str, err := cl.AcceptStream()
str, err := cl.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -27,9 +28,9 @@ var _ = Describe("Stream deadline tests", func() {
acceptedStream := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverStr, err = sess.AcceptStream()
serverStr, err = sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
_, err = serverStr.Read([]byte{0})
Expect(err).ToNot(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"math/rand"
"net"
@@ -88,7 +89,7 @@ var _ = Describe("Drop Tests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -109,7 +110,7 @@ var _ = Describe("Drop Tests", func() {
)
Expect(err).ToNot(HaveOccurred())
defer sess.Close()
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
for i := uint8(1); i <= numMessages; i++ {
b := []byte{0}

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
mrand "math/rand"
"net"
@@ -57,10 +58,10 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
defer sess.Close()
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
@@ -92,7 +93,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -106,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
@@ -126,7 +127,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess
}()

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -56,8 +57,7 @@ var _ = Describe("Handshake RTT tests", func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
_, err := server.Accept()
if err != nil {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -50,7 +51,7 @@ var _ = Describe("Handshake tests", func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
if _, err := server.Accept(); err != nil {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}
@@ -158,7 +159,7 @@ var _ = Describe("Handshake tests", func() {
errChan := make(chan error)
go func() {
defer GinkgoRecover()
_, err := sess.AcceptStream()
_, err := sess.AcceptStream(context.Background())
errChan <- err
}()
Eventually(errChan).Should(Receive(&err))
@@ -236,7 +237,7 @@ var _ = Describe("Handshake tests", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy))
// now accept one session, freeing one spot in the queue
_, err = server.Accept()
_, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
// dial again, and expect that this dial succeeds
sess, err := dial()
@@ -289,7 +290,7 @@ var _ = Describe("Handshake tests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
cs := sess.ConnectionState()
Expect(cs.NegotiatedProtocol).To(Equal(alpn))

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -25,7 +26,7 @@ var _ = Describe("Multiplexing", func() {
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
if err != nil {
return
}
@@ -50,7 +51,7 @@ var _ = Describe("Multiplexing", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -55,11 +56,11 @@ var _ = Describe("TLS session resumption", func() {
go func() {
defer close(done)
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().DidResume).To(BeFalse())
sess, err = server.Accept()
sess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().DidResume).To(BeTrue())
}()

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -42,7 +43,7 @@ var _ = Describe("non-zero RTT", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -67,7 +68,7 @@ var _ = Describe("non-zero RTT", func() {
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"math/rand"
"net"
@@ -33,7 +34,7 @@ var _ = Describe("Stateless Resets", func() {
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -63,7 +64,7 @@ var _ = Describe("Stateless Resets", func() {
},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data := make([]byte, 6)
_, err = str.Read(data)
@@ -86,7 +87,7 @@ var _ = Describe("Stateless Resets", func() {
acceptStopped := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln2.Accept()
_, err := ln2.Accept(context.Background())
Expect(err).To(HaveOccurred())
close(acceptStopped)
}()

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -46,7 +47,7 @@ var _ = Describe("Bidirectional streams", func() {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.OpenStreamSync()
str, err := sess.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
data := testserver.GeneratePRData(25 * i)
go func() {
@@ -70,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptStream()
str, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
@@ -92,7 +93,7 @@ var _ = Describe("Bidirectional streams", func() {
go func() {
defer GinkgoRecover()
var err error
sess, err = server.Accept()
sess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
}()
@@ -109,7 +110,7 @@ var _ = Describe("Bidirectional streams", func() {
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
sess.Close()
@@ -129,7 +130,7 @@ var _ = Describe("Bidirectional streams", func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {

View File

@@ -70,7 +70,7 @@ var _ = Describe("Timeout tests", func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() {
&quic.Config{IdleTimeout: idleTimeout},
)
Expect(err).ToNot(HaveOccurred())
strIn, err := sess.AcceptStream()
strIn, err := sess.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
strOut, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() {
checkTimeoutError(err)
_, err = sess.OpenUniStream()
checkTimeoutError(err)
_, err = sess.AcceptStream()
_, err = sess.AcceptStream(context.Background())
checkTimeoutError(err)
_, err = sess.AcceptUniStream()
_, err = sess.AcceptUniStream(context.Background())
checkTimeoutError(err)
})
@@ -146,9 +146,9 @@ var _ = Describe("Timeout tests", func() {
serverSessionClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed
sess.AcceptStream(context.Background()) // blocks until the session is closed
close(serverSessionClosed)
}()
@@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := sess.AcceptStream()
_, err := sess.AcceptStream(context.Background())
checkTimeoutError(err)
close(done)
}()
@@ -187,9 +187,9 @@ var _ = Describe("Timeout tests", func() {
serverSessionClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed
sess.AcceptStream(context.Background()) // blocks until the session is closed
close(serverSessionClosed)
}()
@@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := sess.AcceptStream()
_, err := sess.AcceptStream(context.Background())
checkTimeoutError(err)
close(done)
}()

View File

@@ -1,6 +1,7 @@
package self_test
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -41,7 +42,7 @@ var _ = Describe("Unidirectional Streams", func() {
runSendingPeer := func(sess quic.Session) {
for i := 0; i < numStreams; i++ {
str, err := sess.OpenUniStreamSync()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
@@ -56,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptUniStream()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
@@ -72,7 +73,7 @@ var _ = Describe("Unidirectional Streams", func() {
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
sess.Close()
@@ -91,7 +92,7 @@ var _ = Describe("Unidirectional Streams", func() {
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
}()
@@ -109,7 +110,7 @@ var _ = Describe("Unidirectional Streams", func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {

View File

@@ -127,11 +127,11 @@ type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// If the session was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptStream() (Stream, error)
AcceptStream(context.Context) (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// If the session was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptUniStream() (ReceiveStream, error)
AcceptUniStream(context.Context) (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
@@ -143,7 +143,7 @@ type Session interface {
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the session was closed due to a timeout, Timeout() will be true.
OpenStreamSync() (Stream, error)
OpenStreamSync(context.Context) (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, Temporary() will be true.
@@ -153,7 +153,7 @@ type Session interface {
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the session was closed due to a timeout, Timeout() will be true.
OpenUniStreamSync() (SendStream, error)
OpenUniStreamSync(context.Context) (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
@@ -232,5 +232,5 @@ type Listener interface {
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new sessions. It should be called in a loop.
Accept() (Session, error)
Accept(context.Context) (Session, error)
}

View File

@@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder {
}
// AcceptStream mocks base method
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream")
ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0)
}
// AcceptUniStream mocks base method
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream")
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(quic_go.ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0)
}
// Close mocks base method
@@ -154,18 +154,18 @@ func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) {
func (m *MockSession) OpenStreamSync(arg0 context.Context) (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@@ -184,18 +184,18 @@ func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) {
func (m *MockSession) OpenUniStreamSync(arg0 context.Context) (quic_go.SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(quic_go.SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync), arg0)
}
// RemoteAddr mocks base method

View File

@@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder {
}
// AcceptStream mocks base method
func (m *MockQuicSession) AcceptStream() (Stream, error) {
func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream")
ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0)
}
// AcceptUniStream mocks base method
func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) {
func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream")
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0)
}
// Close mocks base method
@@ -167,18 +167,18 @@ func (mr *MockQuicSessionMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockQuicSession) OpenStreamSync() (Stream, error) {
func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockQuicSessionMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@@ -197,18 +197,18 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockQuicSession) OpenUniStreamSync() (SendStream, error) {
func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0)
}
// RemoteAddr mocks base method

View File

@@ -5,6 +5,7 @@
package quic
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
}
// AcceptStream mocks base method
func (m *MockStreamManager) AcceptStream() (Stream, error) {
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream")
ret := m.ctrl.Call(m, "AcceptStream", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
}
// AcceptUniStream mocks base method
func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) {
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream")
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
ret0, _ := ret[0].(ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
}
// CloseWithError mocks base method
@@ -152,18 +153,18 @@ func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
}
// OpenStreamSync mocks base method
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
ret0, _ := ret[0].(Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0)
}
// OpenUniStream mocks base method
@@ -182,18 +183,18 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call {
}
// OpenUniStreamSync mocks base method
func (m *MockStreamManager) OpenUniStreamSync() (SendStream, error) {
func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
ret0, _ := ret[0].(SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync() *gomock.Call {
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0)
}
// UpdateLimits mocks base method

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@@ -284,9 +285,11 @@ func populateServerConfig(config *Config) *Config {
}
// Accept returns newly openend sessions
func (s *server) Accept() (Session, error) {
func (s *server) Accept(ctx context.Context) (Session, error) {
var sess Session
select {
case <-ctx.Done():
return nil, ctx.Err()
case sess = <-s.sessionQueue:
return sess, nil
case <-s.errorChan:

View File

@@ -434,7 +434,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept()
serv.Accept(context.Background())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
@@ -461,7 +461,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()
@@ -474,18 +474,33 @@ var _ = Describe("Server", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
for i := 0; i < 3; i++ {
_, err := serv.Accept()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
}
})
It("returns when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("accepts new sessions when the handshake completes", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(sess))
close(done)

View File

@@ -37,10 +37,10 @@ type streamManager interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
OpenStream() (Stream, error)
OpenUniStream() (SendStream, error)
OpenStreamSync() (Stream, error)
OpenUniStreamSync() (SendStream, error)
AcceptStream() (Stream, error)
AcceptUniStream() (ReceiveStream, error)
OpenStreamSync(context.Context) (Stream, error)
OpenUniStreamSync(context.Context) (SendStream, error)
AcceptStream(context.Context) (Stream, error)
AcceptUniStream(context.Context) (ReceiveStream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters) error
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
@@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) {
}
// AcceptStream returns the next stream openend by the peer
func (s *session) AcceptStream() (Stream, error) {
return s.streamsMap.AcceptStream()
func (s *session) AcceptStream(ctx context.Context) (Stream, error) {
return s.streamsMap.AcceptStream(ctx)
}
func (s *session) AcceptUniStream() (ReceiveStream, error) {
return s.streamsMap.AcceptUniStream()
func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
return s.streamsMap.AcceptUniStream(ctx)
}
// OpenStream opens a stream
@@ -1246,16 +1246,16 @@ func (s *session) OpenStream() (Stream, error) {
return s.streamsMap.OpenStream()
}
func (s *session) OpenStreamSync() (Stream, error) {
return s.streamsMap.OpenStreamSync()
func (s *session) OpenStreamSync(ctx context.Context) (Stream, error) {
return s.streamsMap.OpenStreamSync(ctx)
}
func (s *session) OpenUniStream() (SendStream, error) {
return s.streamsMap.OpenUniStream()
}
func (s *session) OpenUniStreamSync() (SendStream, error) {
return s.streamsMap.OpenUniStreamSync()
func (s *session) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
return s.streamsMap.OpenUniStreamSync(ctx)
}
func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {

View File

@@ -367,14 +367,6 @@ var _ = Describe("Session", func() {
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
})
It("accepts new streams", func() {
mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
Context("closing", func() {
var (
runErr error
@@ -1431,8 +1423,8 @@ var _ = Describe("Session", func() {
It("opens streams synchronously", func() {
mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().OpenStreamSync().Return(mstr, nil)
str, err := sess.OpenStreamSync()
streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
str, err := sess.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
@@ -1447,24 +1439,28 @@ var _ = Describe("Session", func() {
It("opens unidirectional streams synchronously", func() {
mstr := NewMockSendStreamI(mockCtrl)
streamManager.EXPECT().OpenUniStreamSync().Return(mstr, nil)
str, err := sess.OpenUniStreamSync()
streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
It("accepts streams", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
mstr := NewMockStreamI(mockCtrl)
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
str, err := sess.AcceptStream()
streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil)
str, err := sess.AcceptStream(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})
It("accepts unidirectional streams", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
mstr := NewMockReceiveStreamI(mockCtrl)
streamManager.EXPECT().AcceptUniStream().Return(mstr, nil)
str, err := sess.AcceptUniStream()
streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
str, err := sess.AcceptUniStream(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(Equal(mstr))
})

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
@@ -108,8 +109,8 @@ func (m *streamsMap) OpenStream() (Stream, error) {
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
func (m *streamsMap) OpenStreamSync() (Stream, error) {
str, err := m.outgoingBidiStreams.OpenStreamSync()
func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
str, err := m.outgoingBidiStreams.OpenStreamSync(ctx)
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
@@ -118,18 +119,18 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) {
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
}
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
str, err := m.outgoingUniStreams.OpenStreamSync()
func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
str, err := m.outgoingUniStreams.OpenStreamSync(ctx)
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
}
func (m *streamsMap) AcceptStream() (Stream, error) {
str, err := m.incomingBidiStreams.AcceptStream()
func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
str, err := m.incomingBidiStreams.AcceptStream(ctx)
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
}
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
str, err := m.incomingUniStreams.AcceptStream()
func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
str, err := m.incomingUniStreams.AcceptStream(ctx)
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
}

View File

@@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -12,8 +13,8 @@ import (
)
type incomingBidiStreamsMap struct {
mutex sync.RWMutex
cond sync.Cond
mutex sync.RWMutex
newStreamChan chan struct{}
streams map[protocol.StreamNum]streamI
// When a stream is deleted before it was accepted, we can't delete it immediately.
@@ -36,9 +37,9 @@ func newIncomingBidiStreamsMap(
newStream func(protocol.StreamNum) streamI,
maxStreams uint64,
queueControlFrame func(wire.Frame),
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
) *incomingBidiStreamsMap {
m := &incomingBidiStreamsMap{
return &incomingBidiStreamsMap{
newStreamChan: make(chan struct{}),
streams: make(map[protocol.StreamNum]streamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -47,38 +48,43 @@ func newIncomingBidiStreamsMap(
nextStreamToOpen: 1,
nextStreamToAccept: 1,
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
// streamNumToID: streamNumToID,
}
m.cond.L = &m.mutex
return m
}
func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum
var str streamI
for {
num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr
}
var ok bool
str, ok = m.streams[num]
if ok {
break
}
m.cond.Wait()
m.mutex.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock()
}
m.nextStreamToAccept++
// If this stream was completed before being accepted, we can delete it now.
if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err
}
}
m.mutex.Unlock()
return str, nil
}
@@ -111,7 +117,10 @@ func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (stream
// * highestStream is only modified by this function
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
m.streams[newNum] = m.newStream(newNum)
m.cond.Signal()
select {
case m.newStreamChan <- struct{}{}:
default:
}
}
m.nextStreamToOpen = num + 1
s := m.streams[num]
@@ -167,5 +176,5 @@ func (m *incomingBidiStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
m.mutex.Unlock()
m.cond.Broadcast()
close(m.newStreamChan)
}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -10,8 +11,8 @@ import (
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
type incomingItemsMap struct {
mutex sync.RWMutex
cond sync.Cond
mutex sync.RWMutex
newStreamChan chan struct{}
streams map[protocol.StreamNum]item
// When a stream is deleted before it was accepted, we can't delete it immediately.
@@ -34,9 +35,9 @@ func newIncomingItemsMap(
newStream func(protocol.StreamNum) item,
maxStreams uint64,
queueControlFrame func(wire.Frame),
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
) *incomingItemsMap {
m := &incomingItemsMap{
return &incomingItemsMap{
newStreamChan: make(chan struct{}),
streams: make(map[protocol.StreamNum]item),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -45,38 +46,43 @@ func newIncomingItemsMap(
nextStreamToOpen: 1,
nextStreamToAccept: 1,
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
// streamNumToID: streamNumToID,
}
m.cond.L = &m.mutex
return m
}
func (m *incomingItemsMap) AcceptStream() (item, error) {
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum
var str item
for {
num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr
}
var ok bool
str, ok = m.streams[num]
if ok {
break
}
m.cond.Wait()
m.mutex.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock()
}
m.nextStreamToAccept++
// If this stream was completed before being accepted, we can delete it now.
if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err
}
}
m.mutex.Unlock()
return str, nil
}
@@ -109,7 +115,10 @@ func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error)
// * highestStream is only modified by this function
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
m.streams[newNum] = m.newStream(newNum)
m.cond.Signal()
select {
case m.newStreamChan <- struct{}{}:
default:
}
}
m.nextStreamToOpen = num + 1
s := m.streams[num]
@@ -165,5 +174,5 @@ func (m *incomingItemsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
m.mutex.Unlock()
m.cond.Broadcast()
close(m.newStreamChan)
}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"github.com/golang/mock/gomock"
@@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() {
It("accepts streams in the right order", func() {
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
str, err = m.AcceptStream()
str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
})
@@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() {
strChan := make(chan item)
go func() {
defer GinkgoRecover()
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
strChan <- str
}()
@@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
})
It("unblocks AcceptStream when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.AcceptStream(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("unblocks AcceptStream when it is closed", func() {
testErr := errors.New("test error")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.AcceptStream()
_, err := m.AcceptStream(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()
@@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() {
It("errors AcceptStream immediately if it is closed", func() {
testErr := errors.New("test error")
m.CloseWithError(testErr)
_, err := m.AcceptStream()
_, err := m.AcceptStream(context.Background())
Expect(err).To(MatchError(testErr))
})
@@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
_, err := m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
Expect(m.DeleteStream(1)).To(Succeed())
@@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() {
_, err := m.GetOrOpenStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(2)).To(Succeed())
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
mockSender.EXPECT().queueControlFrame(gomock.Any())
str, err = m.AcceptStream()
str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
})
@@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(str).To(BeNil())
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
mockSender.EXPECT().queueControlFrame(gomock.Any())
str, err = m.AcceptStream()
str, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
@@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(err).ToNot(HaveOccurred())
// accept all streams
for i := 0; i < 5; i++ {
_, err := m.AcceptStream()
_, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
}
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {

View File

@@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -12,8 +13,8 @@ import (
)
type incomingUniStreamsMap struct {
mutex sync.RWMutex
cond sync.Cond
mutex sync.RWMutex
newStreamChan chan struct{}
streams map[protocol.StreamNum]receiveStreamI
// When a stream is deleted before it was accepted, we can't delete it immediately.
@@ -36,9 +37,9 @@ func newIncomingUniStreamsMap(
newStream func(protocol.StreamNum) receiveStreamI,
maxStreams uint64,
queueControlFrame func(wire.Frame),
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
) *incomingUniStreamsMap {
m := &incomingUniStreamsMap{
return &incomingUniStreamsMap{
newStreamChan: make(chan struct{}),
streams: make(map[protocol.StreamNum]receiveStreamI),
streamsToDelete: make(map[protocol.StreamNum]struct{}),
maxStream: protocol.StreamNum(maxStreams),
@@ -47,38 +48,43 @@ func newIncomingUniStreamsMap(
nextStreamToOpen: 1,
nextStreamToAccept: 1,
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
// streamNumToID: streamNumToID,
}
m.cond.L = &m.mutex
return m
}
func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var num protocol.StreamNum
var str receiveStreamI
for {
num = m.nextStreamToAccept
var ok bool
if m.closeErr != nil {
m.mutex.Unlock()
return nil, m.closeErr
}
var ok bool
str, ok = m.streams[num]
if ok {
break
}
m.cond.Wait()
m.mutex.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock()
}
m.nextStreamToAccept++
// If this stream was completed before being accepted, we can delete it now.
if _, ok := m.streamsToDelete[num]; ok {
delete(m.streamsToDelete, num)
if err := m.deleteStream(num); err != nil {
m.mutex.Unlock()
return nil, err
}
}
m.mutex.Unlock()
return str, nil
}
@@ -111,7 +117,10 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive
// * highestStream is only modified by this function
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
m.streams[newNum] = m.newStream(newNum)
m.cond.Signal()
select {
case m.newStreamChan <- struct{}{}:
default:
}
}
m.nextStreamToOpen = num + 1
s := m.streams[num]
@@ -167,5 +176,5 @@ func (m *incomingUniStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
m.mutex.Unlock()
m.cond.Broadcast()
close(m.newStreamChan)
}

View File

@@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -14,10 +15,12 @@ import (
type outgoingBidiStreamsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]streamI
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@@ -34,6 +37,7 @@ func newOutgoingBidiStreamsMap(
) *outgoingBidiStreamsMap {
return &outgoingBidiStreamsMap{
streams: make(map[protocol.StreamNum]streamI),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -57,7 +61,7 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) {
return m.openStream(), nil
}
func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -65,17 +69,32 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@@ -86,7 +105,7 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@@ -159,9 +178,15 @@ func (m *outgoingBidiStreamsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@@ -172,7 +197,9 @@ func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -12,10 +13,12 @@ import (
type outgoingItemsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]item
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@@ -32,6 +35,7 @@ func newOutgoingItemsMap(
) *outgoingItemsMap {
return &outgoingItemsMap{
streams: make(map[protocol.StreamNum]item),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) {
return m.openStream(), nil
}
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"github.com/golang/mock/gomock"
@@ -106,12 +107,19 @@ var _ = Describe("Streams Map (outgoing)", func() {
expectTooManyStreamsError(err)
})
It("returns immediately when called with a canceled context", func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.OpenStreamSync(ctx)
Expect(err).To(MatchError("context canceled"))
})
It("blocks until a stream can be opened synchronously", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(done)
@@ -122,12 +130,34 @@ var _ = Describe("Streams Map (outgoing)", func() {
Eventually(done).Should(BeClosed())
})
It("unblocks when the context is canceled", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
// make sure that the next stream openend is stream 1
m.SetMaxStream(1000)
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
})
It("opens streams in the right order", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(done1)
@@ -136,7 +166,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
close(done2)
@@ -155,20 +185,20 @@ var _ = Describe("Streams Map (outgoing)", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
done <- struct{}{}
}()
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
done <- struct{}{}
}()
Consistently(done).ShouldNot(Receive())
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).To(MatchError("test done"))
done <- struct{}{}
}()
@@ -188,7 +218,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
openedSync := make(chan struct{})
go func() {
defer GinkgoRecover()
str, err := m.OpenStreamSync()
str, err := m.OpenStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
close(openedSync)
@@ -229,7 +259,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := m.OpenStreamSync()
_, err := m.OpenStreamSync(context.Background())
Expect(err).To(MatchError(testErr))
close(done)
}()

View File

@@ -5,6 +5,7 @@
package quic
import (
"context"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -14,10 +15,12 @@ import (
type outgoingUniStreamsMap struct {
mutex sync.RWMutex
openQueue []chan struct{}
streams map[protocol.StreamNum]sendStreamI
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
@@ -34,6 +37,7 @@ func newOutgoingUniStreamsMap(
) *outgoingUniStreamsMap {
return &outgoingUniStreamsMap{
streams: make(map[protocol.StreamNum]sendStreamI),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -57,7 +61,7 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
return m.openStream(), nil
}
func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -65,17 +69,32 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
return nil, m.closeErr
}
if err := ctx.Err(); err != nil {
return nil, err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
<-waitChan
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
@@ -86,7 +105,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
delete(m.openQueue, queuePos)
m.unblockOpenSync()
return str, nil
}
@@ -159,9 +178,15 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
select {
case m.openQueue[0] <- struct{}{}:
default:
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
close(c)
m.openQueue[qp] = nil
m.lowestInQueue = qp + 1
return
}
}
@@ -172,7 +197,9 @@ func (m *outgoingUniStreamsMap) CloseWithError(err error) {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
close(c)
if c != nil {
close(c)
}
}
m.mutex.Unlock()
}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
@@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() {
It("accepts bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeAssignableToTypeOf(&stream{}))
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
@@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() {
It("accepts unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred())
str, err := m.AcceptUniStream()
str, err := m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
@@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenReceiveStream(id)
Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(id)).To(Succeed())
str, err := m.AcceptStream()
str, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
Expect(str.StreamID()).To(Equal(id))
@@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() {
_, err := m.GetOrOpenReceiveStream(id)
Expect(err).ToNot(HaveOccurred())
Expect(m.DeleteStream(id)).To(Succeed())
str, err := m.AcceptUniStream()
str, err := m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
Expect(str.StreamID()).To(Equal(id))
@@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() {
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred())
_, err = m.AcceptStream()
_, err = m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
@@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() {
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred())
_, err = m.AcceptUniStream()
_, err = m.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
@@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() {
_, err = m.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error()))
_, err = m.AcceptStream()
_, err = m.AcceptStream(context.Background())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error()))
_, err = m.AcceptUniStream()
_, err = m.AcceptUniStream(context.Background())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(testErr.Error()))
})