forked from quic-go/quic-go
Merge pull request #1952 from lucas-clemente/contexts
add contexts to all blocking functions
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
||||
- Use a varint for error codes.
|
||||
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
||||
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
|
||||
|
||||
## v0.11.0 (2019-04-05)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package benchmark
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -44,7 +45,7 @@ func init() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr <- ln.Addr()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// wait for the client to complete the handshake before sending the data
|
||||
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
|
||||
@@ -66,7 +67,7 @@ func init() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(handshakeChan)
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
@@ -35,11 +36,11 @@ func echoServer() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sess, err := listener.Accept()
|
||||
sess, err := listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream, err := sess.AcceptStream()
|
||||
stream, err := sess.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -58,7 +59,7 @@ func clientMain() error {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := session.OpenStreamSync()
|
||||
stream, err := session.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -100,7 +101,7 @@ func (c *client) dial() error {
|
||||
|
||||
func (c *client) setupSession() error {
|
||||
// open the control stream
|
||||
str, err := c.session.OpenUniStreamSync()
|
||||
str, err := c.session.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -138,7 +139,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
str, err := c.session.OpenStreamSync()
|
||||
str, err := c.session.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package http3
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -126,8 +127,8 @@ var _ = Describe("Client", func() {
|
||||
testErr := errors.New("stream open error")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session := mockquic.NewMockSession(mockCtrl)
|
||||
session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
@@ -169,7 +170,7 @@ var _ = Describe("Client", func() {
|
||||
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
|
||||
str = mockquic.NewMockStream(mockCtrl)
|
||||
sess = mockquic.NewMockSession(mockCtrl)
|
||||
sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1)
|
||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return sess, nil
|
||||
}
|
||||
@@ -179,7 +180,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("sends a request", func() {
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return buf.Write(p)
|
||||
@@ -200,7 +201,7 @@ var _ = Describe("Client", func() {
|
||||
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
||||
rw.WriteHeader(418)
|
||||
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
@@ -234,7 +235,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
strBuf = &bytes.Buffer{}
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
body := &mockBody{}
|
||||
body.SetData([]byte("request body"))
|
||||
var err error
|
||||
@@ -295,7 +296,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("adds the gzip header to requests", func() {
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return buf.Write(p)
|
||||
@@ -310,7 +311,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("doesn't add gzip if the header disable it", func() {
|
||||
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
return buf.Write(p)
|
||||
@@ -324,7 +325,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("decompresses the response", func() {
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
||||
rw.Header().Set("Content-Encoding", "gzip")
|
||||
@@ -348,7 +349,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("only decompresses the response if the response contains the right content-encoding header", func() {
|
||||
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
rw := newResponseWriter(buf, utils.DefaultLogger)
|
||||
rw.Write([]byte("not gzipped"))
|
||||
|
||||
@@ -2,6 +2,7 @@ package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -91,8 +92,8 @@ var _ = Describe("RoundTripper", func() {
|
||||
testErr := errors.New("test err")
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().OpenStreamSync().Return(nil, testErr)
|
||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||
_, err = rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
@@ -128,8 +129,8 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("reuses existing clients", func() {
|
||||
closed := make(chan struct{})
|
||||
testErr := errors.New("test err")
|
||||
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2)
|
||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -2,6 +2,7 @@ package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -114,7 +115,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||
s.listenerMutex.Unlock()
|
||||
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (s *Server) handleConn(sess quic.Session) {
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
|
||||
// send a SETTINGS frame
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStream()
|
||||
if err != nil {
|
||||
s.logger.Debugf("Opening the control stream failed.")
|
||||
return
|
||||
@@ -137,7 +138,7 @@ func (s *Server) handleConn(sess quic.Session) {
|
||||
str.Write(buf.Bytes())
|
||||
|
||||
for {
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
s.logger.Debugf("Accepting stream failed: %s", err)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
@@ -33,13 +35,13 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
defer GinkgoRecover()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if _, err = str.Write(testserver.PRData); err != nil {
|
||||
Expect(err).To(MatchError(fmt.Sprintf("Stream %d was reset with error code %d", str.StreamID(), str.StreamID())))
|
||||
@@ -75,7 +77,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel around 2/3 of the streams
|
||||
if rand.Int31()%3 != 0 {
|
||||
@@ -119,7 +121,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// only read some data from about 1/3 of the streams
|
||||
if rand.Int31()%3 != 0 {
|
||||
@@ -167,7 +169,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
if err != nil {
|
||||
@@ -196,12 +198,12 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
var canceledCounter int32
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel about 2/3 of the streams
|
||||
if rand.Int31()%3 != 0 {
|
||||
@@ -227,12 +229,12 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
var canceledCounter int32
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// only write some data from about 1/3 of the streams, then cancel
|
||||
if rand.Int31()%3 != 0 {
|
||||
@@ -265,13 +267,13 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
defer GinkgoRecover()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel about half of the streams
|
||||
if rand.Int31()%2 == 0 {
|
||||
@@ -303,7 +305,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel around half of the streams
|
||||
if rand.Int31()%2 == 0 {
|
||||
@@ -339,13 +341,13 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
defer GinkgoRecover()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// cancel about half of the streams
|
||||
length := len(testserver.PRData)
|
||||
@@ -382,7 +384,7 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
r := io.Reader(str)
|
||||
@@ -418,4 +420,143 @@ var _ = Describe("Stream Cancelations", func() {
|
||||
Expect(server.Close()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("canceling the context", func() {
|
||||
It("downloads data when the receiving peer cancels the context for accepting streams", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
<-ticker.C
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(testserver.PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
&quic.Config{MaxIncomingUniStreams: numStreams / 3},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var numToAccept int32
|
||||
var counter int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for atomic.LoadInt32(&numToAccept) < numStreams {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// cancel accepting half of the streams
|
||||
if rand.Int31()%2 == 0 {
|
||||
cancel()
|
||||
} else {
|
||||
atomic.AddInt32(&numToAccept, 1)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.AcceptUniStream(ctx)
|
||||
if err != nil {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
return
|
||||
}
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testserver.PRData))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
count := atomic.LoadInt32(&counter)
|
||||
fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count)
|
||||
Expect(count).To(BeNumerically(">", numStreams/2))
|
||||
Expect(sess.Close()).To(Succeed())
|
||||
Expect(server.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("downloads data when the sending peer cancels the context for opening streams", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var numCanceled int32
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var numOpened int
|
||||
ticker := time.NewTicker(250 * time.Microsecond)
|
||||
for numOpened < numStreams {
|
||||
<-ticker.C
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// cancel accepting half of the streams
|
||||
shouldCancel := rand.Int31()%2 == 0
|
||||
|
||||
if shouldCancel {
|
||||
time.AfterFunc(5*time.Millisecond, cancel)
|
||||
}
|
||||
str, err := sess.OpenUniStreamSync(ctx)
|
||||
if err != nil {
|
||||
atomic.AddInt32(&numCanceled, 1)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
continue
|
||||
}
|
||||
numOpened++
|
||||
go func(str quic.SendStream) {
|
||||
defer GinkgoRecover()
|
||||
_, err = str.Write(testserver.PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}(str)
|
||||
}
|
||||
}()
|
||||
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
&quic.Config{
|
||||
MaxIncomingUniStreams: 5,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
<-ticker.C
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testserver.PRData))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
count := atomic.LoadInt32(&numCanceled)
|
||||
fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count)
|
||||
Expect(count).To(BeNumerically(">", numStreams/5))
|
||||
Expect(sess.Close()).To(Succeed())
|
||||
Expect(server.Close()).To(Succeed())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
@@ -25,7 +26,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -51,7 +52,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.Close()
|
||||
str, err := cl.AcceptStream()
|
||||
str, err := cl.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -27,9 +28,9 @@ var _ = Describe("Stream deadline tests", func() {
|
||||
acceptedStream := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverStr, err = sess.AcceptStream()
|
||||
serverStr, err = sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = serverStr.Read([]byte{0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -88,7 +89,7 @@ var _ = Describe("Drop Tests", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -109,7 +110,7 @@ var _ = Describe("Drop Tests", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer sess.Close()
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := uint8(1); i <= numMessages; i++ {
|
||||
b := []byte{0}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
@@ -57,10 +58,10 @@ var _ = Describe("Handshake drop tests", func() {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer sess.Close()
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
@@ -92,7 +93,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -106,7 +107,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
@@ -126,7 +127,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverSessionChan <- sess
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -56,8 +57,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(acceptStopped)
|
||||
for {
|
||||
_, err := server.Accept()
|
||||
if err != nil {
|
||||
if _, err := server.Accept(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -50,7 +51,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(acceptStopped)
|
||||
for {
|
||||
if _, err := server.Accept(); err != nil {
|
||||
if _, err := server.Accept(context.Background()); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -158,7 +159,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
errChan <- err
|
||||
}()
|
||||
Eventually(errChan).Should(Receive(&err))
|
||||
@@ -236,7 +237,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy))
|
||||
|
||||
// now accept one session, freeing one spot in the queue
|
||||
_, err = server.Accept()
|
||||
_, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// dial again, and expect that this dial succeeds
|
||||
sess, err := dial()
|
||||
@@ -289,7 +290,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cs := sess.ConnectionState()
|
||||
Expect(cs.NegotiatedProtocol).To(Equal(alpn))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -25,7 +26,7 @@ var _ = Describe("Multiplexing", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -50,7 +51,7 @@ var _ = Describe("Multiplexing", func() {
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -55,11 +56,11 @@ var _ = Describe("TLS session resumption", func() {
|
||||
go func() {
|
||||
defer close(done)
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.ConnectionState().DidResume).To(BeFalse())
|
||||
|
||||
sess, err = server.Accept()
|
||||
sess, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.ConnectionState().DidResume).To(BeTrue())
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -42,7 +43,7 @@ var _ = Describe("non-zero RTT", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -67,7 +68,7 @@ var _ = Describe("non-zero RTT", func() {
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -33,7 +34,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
sess, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -63,7 +64,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := make([]byte, 6)
|
||||
_, err = str.Read(data)
|
||||
@@ -86,7 +87,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||
acceptStopped := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln2.Accept()
|
||||
_, err := ln2.Accept(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
close(acceptStopped)
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -46,7 +47,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.OpenStreamSync()
|
||||
str, err := sess.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := testserver.GeneratePRData(25 * i)
|
||||
go func() {
|
||||
@@ -70,7 +71,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptStream()
|
||||
str, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -92,7 +93,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
sess, err = server.Accept()
|
||||
sess, err = server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(sess)
|
||||
}()
|
||||
@@ -109,7 +110,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(sess)
|
||||
sess.Close()
|
||||
@@ -129,7 +130,7 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -70,7 +70,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -95,7 +95,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
&quic.Config{IdleTimeout: idleTimeout},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strIn, err := sess.AcceptStream()
|
||||
strIn, err := sess.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strOut, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -116,9 +116,9 @@ var _ = Describe("Timeout tests", func() {
|
||||
checkTimeoutError(err)
|
||||
_, err = sess.OpenUniStream()
|
||||
checkTimeoutError(err)
|
||||
_, err = sess.AcceptStream()
|
||||
_, err = sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
_, err = sess.AcceptUniStream()
|
||||
_, err = sess.AcceptUniStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
})
|
||||
|
||||
@@ -146,9 +146,9 @@ var _ = Describe("Timeout tests", func() {
|
||||
serverSessionClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.AcceptStream() // blocks until the session is closed
|
||||
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||
close(serverSessionClosed)
|
||||
}()
|
||||
|
||||
@@ -162,7 +162,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
@@ -187,9 +187,9 @@ var _ = Describe("Timeout tests", func() {
|
||||
serverSessionClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.AcceptStream() // blocks until the session is closed
|
||||
sess.AcceptStream(context.Background()) // blocks until the session is closed
|
||||
close(serverSessionClosed)
|
||||
}()
|
||||
|
||||
@@ -212,7 +212,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := sess.AcceptStream()
|
||||
_, err := sess.AcceptStream(context.Background())
|
||||
checkTimeoutError(err)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -41,7 +42,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
|
||||
runSendingPeer := func(sess quic.Session) {
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -56,7 +57,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptUniStream()
|
||||
str, err := sess.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -72,7 +73,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(sess)
|
||||
sess.Close()
|
||||
@@ -91,7 +92,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(sess)
|
||||
}()
|
||||
@@ -109,7 +110,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
sess, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
10
interface.go
10
interface.go
@@ -127,11 +127,11 @@ type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// If the session was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptStream() (Stream, error)
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
// If the session was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
@@ -143,7 +143,7 @@ type Session interface {
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the session was closed due to a timeout, Timeout() will be true.
|
||||
OpenStreamSync() (Stream, error)
|
||||
OpenStreamSync(context.Context) (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// When reaching the peer's stream limit, Temporary() will be true.
|
||||
@@ -153,7 +153,7 @@ type Session interface {
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the session was closed due to a timeout, Timeout() will be true.
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
OpenUniStreamSync(context.Context) (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
@@ -232,5 +232,5 @@ type Listener interface {
|
||||
// Addr returns the local network addr that the server is listening on.
|
||||
Addr() net.Addr
|
||||
// Accept returns new sessions. It should be called in a loop.
|
||||
Accept() (Session, error)
|
||||
Accept(context.Context) (Session, error)
|
||||
}
|
||||
|
||||
@@ -39,33 +39,33 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder {
|
||||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
|
||||
func (m *MockSession) AcceptStream(arg0 context.Context) (quic_go.Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(quic_go.Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
|
||||
func (m *MockSession) AcceptUniStream(arg0 context.Context) (quic_go.ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(quic_go.ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
@@ -154,18 +154,18 @@ func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenStreamSync mocks base method
|
||||
func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) {
|
||||
func (m *MockSession) OpenStreamSync(arg0 context.Context) (quic_go.Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
|
||||
ret0, _ := ret[0].(quic_go.Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStreamSync indicates an expected call of OpenStreamSync
|
||||
func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync), arg0)
|
||||
}
|
||||
|
||||
// OpenUniStream mocks base method
|
||||
@@ -184,18 +184,18 @@ func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenUniStreamSync mocks base method
|
||||
func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) {
|
||||
func (m *MockSession) OpenUniStreamSync(arg0 context.Context) (quic_go.SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
|
||||
ret0, _ := ret[0].(quic_go.SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
|
||||
func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
|
||||
func (mr *MockSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync), arg0)
|
||||
}
|
||||
|
||||
// RemoteAddr mocks base method
|
||||
|
||||
@@ -38,33 +38,33 @@ func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder {
|
||||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockQuicSession) AcceptStream() (Stream, error) {
|
||||
func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockQuicSession) AcceptUniStream() (ReceiveStream, error) {
|
||||
func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
@@ -167,18 +167,18 @@ func (mr *MockQuicSessionMockRecorder) OpenStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenStreamSync mocks base method
|
||||
func (m *MockQuicSession) OpenStreamSync() (Stream, error) {
|
||||
func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStreamSync indicates an expected call of OpenStreamSync
|
||||
func (mr *MockQuicSessionMockRecorder) OpenStreamSync() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync), arg0)
|
||||
}
|
||||
|
||||
// OpenUniStream mocks base method
|
||||
@@ -197,18 +197,18 @@ func (mr *MockQuicSessionMockRecorder) OpenUniStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenUniStreamSync mocks base method
|
||||
func (m *MockQuicSession) OpenUniStreamSync() (SendStream, error) {
|
||||
func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
|
||||
ret0, _ := ret[0].(SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
|
||||
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
|
||||
func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0)
|
||||
}
|
||||
|
||||
// RemoteAddr mocks base method
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -37,33 +38,33 @@ func (m *MockStreamManager) EXPECT() *MockStreamManagerMockRecorder {
|
||||
}
|
||||
|
||||
// AcceptStream mocks base method
|
||||
func (m *MockStreamManager) AcceptStream() (Stream, error) {
|
||||
func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptStream")
|
||||
ret := m.ctrl.Call(m, "AcceptStream", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptStream indicates an expected call of AcceptStream
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0)
|
||||
}
|
||||
|
||||
// AcceptUniStream mocks base method
|
||||
func (m *MockStreamManager) AcceptUniStream() (ReceiveStream, error) {
|
||||
func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||
ret := m.ctrl.Call(m, "AcceptUniStream", arg0)
|
||||
ret0, _ := ret[0].(ReceiveStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0)
|
||||
}
|
||||
|
||||
// CloseWithError mocks base method
|
||||
@@ -152,18 +153,18 @@ func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenStreamSync mocks base method
|
||||
func (m *MockStreamManager) OpenStreamSync() (Stream, error) {
|
||||
func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenStreamSync", arg0)
|
||||
ret0, _ := ret[0].(Stream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenStreamSync indicates an expected call of OpenStreamSync
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStreamSync() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0)
|
||||
}
|
||||
|
||||
// OpenUniStream mocks base method
|
||||
@@ -182,18 +183,18 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call {
|
||||
}
|
||||
|
||||
// OpenUniStreamSync mocks base method
|
||||
func (m *MockStreamManager) OpenUniStreamSync() (SendStream, error) {
|
||||
func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync")
|
||||
ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
|
||||
ret0, _ := ret[0].(SendStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
|
||||
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync() *gomock.Call {
|
||||
func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0)
|
||||
}
|
||||
|
||||
// UpdateLimits mocks base method
|
||||
|
||||
@@ -2,6 +2,7 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -284,9 +285,11 @@ func populateServerConfig(config *Config) *Config {
|
||||
}
|
||||
|
||||
// Accept returns newly openend sessions
|
||||
func (s *server) Accept() (Session, error) {
|
||||
func (s *server) Accept(ctx context.Context) (Session, error) {
|
||||
var sess Session
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case sess = <-s.sessionQueue:
|
||||
return sess, nil
|
||||
case <-s.errorChan:
|
||||
|
||||
@@ -434,7 +434,7 @@ var _ = Describe("Server", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.Accept()
|
||||
serv.Accept(context.Background())
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
@@ -461,7 +461,7 @@ var _ = Describe("Server", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept()
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
@@ -474,18 +474,33 @@ var _ = Describe("Server", func() {
|
||||
testErr := errors.New("test err")
|
||||
serv.setCloseError(testErr)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := serv.Accept()
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
}
|
||||
})
|
||||
|
||||
It("returns when the context is canceled", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := serv.Accept(ctx)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts new sessions when the handshake completes", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := serv.Accept()
|
||||
s, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).To(Equal(sess))
|
||||
close(done)
|
||||
|
||||
24
session.go
24
session.go
@@ -37,10 +37,10 @@ type streamManager interface {
|
||||
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
|
||||
OpenStream() (Stream, error)
|
||||
OpenUniStream() (SendStream, error)
|
||||
OpenStreamSync() (Stream, error)
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
AcceptStream() (Stream, error)
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
OpenStreamSync(context.Context) (Stream, error)
|
||||
OpenUniStreamSync(context.Context) (SendStream, error)
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
DeleteStream(protocol.StreamID) error
|
||||
UpdateLimits(*handshake.TransportParameters) error
|
||||
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
|
||||
@@ -1233,12 +1233,12 @@ func (s *session) logPacket(packet *packedPacket) {
|
||||
}
|
||||
|
||||
// AcceptStream returns the next stream openend by the peer
|
||||
func (s *session) AcceptStream() (Stream, error) {
|
||||
return s.streamsMap.AcceptStream()
|
||||
func (s *session) AcceptStream(ctx context.Context) (Stream, error) {
|
||||
return s.streamsMap.AcceptStream(ctx)
|
||||
}
|
||||
|
||||
func (s *session) AcceptUniStream() (ReceiveStream, error) {
|
||||
return s.streamsMap.AcceptUniStream()
|
||||
func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||
return s.streamsMap.AcceptUniStream(ctx)
|
||||
}
|
||||
|
||||
// OpenStream opens a stream
|
||||
@@ -1246,16 +1246,16 @@ func (s *session) OpenStream() (Stream, error) {
|
||||
return s.streamsMap.OpenStream()
|
||||
}
|
||||
|
||||
func (s *session) OpenStreamSync() (Stream, error) {
|
||||
return s.streamsMap.OpenStreamSync()
|
||||
func (s *session) OpenStreamSync(ctx context.Context) (Stream, error) {
|
||||
return s.streamsMap.OpenStreamSync(ctx)
|
||||
}
|
||||
|
||||
func (s *session) OpenUniStream() (SendStream, error) {
|
||||
return s.streamsMap.OpenUniStream()
|
||||
}
|
||||
|
||||
func (s *session) OpenUniStreamSync() (SendStream, error) {
|
||||
return s.streamsMap.OpenUniStreamSync()
|
||||
func (s *session) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
|
||||
return s.streamsMap.OpenUniStreamSync(ctx)
|
||||
}
|
||||
|
||||
func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {
|
||||
|
||||
@@ -367,14 +367,6 @@ var _ = Describe("Session", func() {
|
||||
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
|
||||
})
|
||||
|
||||
It("accepts new streams", func() {
|
||||
mstr := NewMockStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
var (
|
||||
runErr error
|
||||
@@ -1431,8 +1423,8 @@ var _ = Describe("Session", func() {
|
||||
|
||||
It("opens streams synchronously", func() {
|
||||
mstr := NewMockStreamI(mockCtrl)
|
||||
streamManager.EXPECT().OpenStreamSync().Return(mstr, nil)
|
||||
str, err := sess.OpenStreamSync()
|
||||
streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
|
||||
str, err := sess.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
@@ -1447,24 +1439,28 @@ var _ = Describe("Session", func() {
|
||||
|
||||
It("opens unidirectional streams synchronously", func() {
|
||||
mstr := NewMockSendStreamI(mockCtrl)
|
||||
streamManager.EXPECT().OpenUniStreamSync().Return(mstr, nil)
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
|
||||
str, err := sess.OpenUniStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
It("accepts streams", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
mstr := NewMockStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptStream()
|
||||
streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil)
|
||||
str, err := sess.AcceptStream(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
It("accepts unidirectional streams", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
mstr := NewMockReceiveStreamI(mockCtrl)
|
||||
streamManager.EXPECT().AcceptUniStream().Return(mstr, nil)
|
||||
str, err := sess.AcceptUniStream()
|
||||
streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
|
||||
str, err := sess.AcceptUniStream(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(Equal(mstr))
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -108,8 +109,8 @@ func (m *streamsMap) OpenStream() (Stream, error) {
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenStreamSync() (Stream, error) {
|
||||
str, err := m.outgoingBidiStreams.OpenStreamSync()
|
||||
func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
|
||||
str, err := m.outgoingBidiStreams.OpenStreamSync(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
@@ -118,18 +119,18 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) {
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
|
||||
str, err := m.outgoingUniStreams.OpenStreamSync()
|
||||
func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
|
||||
str, err := m.outgoingUniStreams.OpenStreamSync(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptStream() (Stream, error) {
|
||||
str, err := m.incomingBidiStreams.AcceptStream()
|
||||
func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
|
||||
str, err := m.incomingBidiStreams.AcceptStream(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
|
||||
str, err := m.incomingUniStreams.AcceptStream()
|
||||
func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||
str, err := m.incomingUniStreams.AcceptStream(ctx)
|
||||
return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -12,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type incomingBidiStreamsMap struct {
|
||||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
mutex sync.RWMutex
|
||||
newStreamChan chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]streamI
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
@@ -36,9 +37,9 @@ func newIncomingBidiStreamsMap(
|
||||
newStream func(protocol.StreamNum) streamI,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingBidiStreamsMap {
|
||||
m := &incomingBidiStreamsMap{
|
||||
return &incomingBidiStreamsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
streams: make(map[protocol.StreamNum]streamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -47,38 +48,43 @@ func newIncomingBidiStreamsMap(
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
|
||||
func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str streamI
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
m.mutex.Unlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
@@ -111,7 +117,10 @@ func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (stream
|
||||
// * highestStream is only modified by this function
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
select {
|
||||
case m.newStreamChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
@@ -167,5 +176,5 @@ func (m *incomingBidiStreamsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
m.cond.Broadcast()
|
||||
close(m.newStreamChan)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -10,8 +11,8 @@ import (
|
||||
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
|
||||
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
|
||||
type incomingItemsMap struct {
|
||||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
mutex sync.RWMutex
|
||||
newStreamChan chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]item
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
@@ -34,9 +35,9 @@ func newIncomingItemsMap(
|
||||
newStream func(protocol.StreamNum) item,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingItemsMap {
|
||||
m := &incomingItemsMap{
|
||||
return &incomingItemsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -45,38 +46,43 @@ func newIncomingItemsMap(
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *incomingItemsMap) AcceptStream() (item, error) {
|
||||
func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str item
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
m.mutex.Unlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
@@ -109,7 +115,10 @@ func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error)
|
||||
// * highestStream is only modified by this function
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
select {
|
||||
case m.newStreamChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
@@ -165,5 +174,5 @@ func (m *incomingItemsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
m.cond.Broadcast()
|
||||
close(m.newStreamChan)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@@ -66,10 +67,10 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
It("accepts streams in the right order", func() {
|
||||
_, err := m.GetOrOpenStream(2) // open streams 1 and 2
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
@@ -90,7 +91,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
strChan := make(chan item)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
strChan <- str
|
||||
}()
|
||||
@@ -103,12 +104,26 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
Expect(acceptedStr.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("unblocks AcceptStream when the context is canceled", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.AcceptStream(ctx)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("unblocks AcceptStream when it is closed", func() {
|
||||
testErr := errors.New("test error")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
@@ -120,7 +135,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
It("errors AcceptStream immediately if it is closed", func() {
|
||||
testErr := errors.New("test error")
|
||||
m.CloseWithError(testErr)
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
@@ -141,7 +156,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
_, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
Expect(m.DeleteStream(1)).To(Succeed())
|
||||
@@ -154,12 +169,12 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
_, err := m.GetOrOpenStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(2)).To(Succeed())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
})
|
||||
@@ -174,7 +189,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
Expect(str).To(BeNil())
|
||||
// when accepting this stream, it will get deleted, and a MAX_STREAMS frame is queued
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str, err = m.AcceptStream()
|
||||
str, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
@@ -191,7 +206,7 @@ var _ = Describe("Streams Map (incoming)", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// accept all streams
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := m.AcceptStream()
|
||||
_, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -12,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type incomingUniStreamsMap struct {
|
||||
mutex sync.RWMutex
|
||||
cond sync.Cond
|
||||
mutex sync.RWMutex
|
||||
newStreamChan chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]receiveStreamI
|
||||
// When a stream is deleted before it was accepted, we can't delete it immediately.
|
||||
@@ -36,9 +37,9 @@ func newIncomingUniStreamsMap(
|
||||
newStream func(protocol.StreamNum) receiveStreamI,
|
||||
maxStreams uint64,
|
||||
queueControlFrame func(wire.Frame),
|
||||
// streamNumToID func(protocol.StreamNum) protocol.StreamID,
|
||||
) *incomingUniStreamsMap {
|
||||
m := &incomingUniStreamsMap{
|
||||
return &incomingUniStreamsMap{
|
||||
newStreamChan: make(chan struct{}),
|
||||
streams: make(map[protocol.StreamNum]receiveStreamI),
|
||||
streamsToDelete: make(map[protocol.StreamNum]struct{}),
|
||||
maxStream: protocol.StreamNum(maxStreams),
|
||||
@@ -47,38 +48,43 @@ func newIncomingUniStreamsMap(
|
||||
nextStreamToOpen: 1,
|
||||
nextStreamToAccept: 1,
|
||||
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
|
||||
// streamNumToID: streamNumToID,
|
||||
}
|
||||
m.cond.L = &m.mutex
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
|
||||
func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var num protocol.StreamNum
|
||||
var str receiveStreamI
|
||||
for {
|
||||
num = m.nextStreamToAccept
|
||||
var ok bool
|
||||
if m.closeErr != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, m.closeErr
|
||||
}
|
||||
var ok bool
|
||||
str, ok = m.streams[num]
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
m.cond.Wait()
|
||||
m.mutex.Unlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-m.newStreamChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
}
|
||||
m.nextStreamToAccept++
|
||||
// If this stream was completed before being accepted, we can delete it now.
|
||||
if _, ok := m.streamsToDelete[num]; ok {
|
||||
delete(m.streamsToDelete, num)
|
||||
if err := m.deleteStream(num); err != nil {
|
||||
m.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return str, nil
|
||||
}
|
||||
|
||||
@@ -111,7 +117,10 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive
|
||||
// * highestStream is only modified by this function
|
||||
for newNum := m.nextStreamToOpen; newNum <= num; newNum++ {
|
||||
m.streams[newNum] = m.newStream(newNum)
|
||||
m.cond.Signal()
|
||||
select {
|
||||
case m.newStreamChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
m.nextStreamToOpen = num + 1
|
||||
s := m.streams[num]
|
||||
@@ -167,5 +176,5 @@ func (m *incomingUniStreamsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
m.cond.Broadcast()
|
||||
close(m.newStreamChan)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -14,10 +15,12 @@ import (
|
||||
type outgoingBidiStreamsMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
openQueue []chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]streamI
|
||||
|
||||
openQueue map[uint64]chan struct{}
|
||||
lowestInQueue uint64
|
||||
highestInQueue uint64
|
||||
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
@@ -34,6 +37,7 @@ func newOutgoingBidiStreamsMap(
|
||||
) *outgoingBidiStreamsMap {
|
||||
return &outgoingBidiStreamsMap{
|
||||
streams: make(map[protocol.StreamNum]streamI),
|
||||
openQueue: make(map[uint64]chan struct{}),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
@@ -57,7 +61,7 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
|
||||
func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -65,17 +69,32 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
waitChan := make(chan struct{}, 1)
|
||||
m.openQueue = append(m.openQueue, waitChan)
|
||||
queuePos := m.highestInQueue
|
||||
m.highestInQueue++
|
||||
if len(m.openQueue) == 0 {
|
||||
m.lowestInQueue = queuePos
|
||||
}
|
||||
m.openQueue[queuePos] = waitChan
|
||||
m.maybeSendBlockedFrame()
|
||||
|
||||
for {
|
||||
m.mutex.Unlock()
|
||||
<-waitChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.mutex.Lock()
|
||||
delete(m.openQueue, queuePos)
|
||||
return nil, ctx.Err()
|
||||
case <-waitChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
|
||||
if m.closeErr != nil {
|
||||
@@ -86,7 +105,7 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) {
|
||||
continue
|
||||
}
|
||||
str := m.openStream()
|
||||
m.openQueue = m.openQueue[1:]
|
||||
delete(m.openQueue, queuePos)
|
||||
m.unblockOpenSync()
|
||||
return str, nil
|
||||
}
|
||||
@@ -159,9 +178,15 @@ func (m *outgoingBidiStreamsMap) unblockOpenSync() {
|
||||
if len(m.openQueue) == 0 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.openQueue[0] <- struct{}{}:
|
||||
default:
|
||||
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
|
||||
c, ok := m.openQueue[qp]
|
||||
if !ok { // entry was deleted because the context was canceled
|
||||
continue
|
||||
}
|
||||
close(c)
|
||||
m.openQueue[qp] = nil
|
||||
m.lowestInQueue = qp + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +197,9 @@ func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
for _, c := range m.openQueue {
|
||||
close(c)
|
||||
if c != nil {
|
||||
close(c)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -12,10 +13,12 @@ import (
|
||||
type outgoingItemsMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
openQueue []chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]item
|
||||
|
||||
openQueue map[uint64]chan struct{}
|
||||
lowestInQueue uint64
|
||||
highestInQueue uint64
|
||||
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
@@ -32,6 +35,7 @@ func newOutgoingItemsMap(
|
||||
) *outgoingItemsMap {
|
||||
return &outgoingItemsMap{
|
||||
streams: make(map[protocol.StreamNum]item),
|
||||
openQueue: make(map[uint64]chan struct{}),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
@@ -55,7 +59,7 @@ func (m *outgoingItemsMap) OpenStream() (item, error) {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -63,17 +67,32 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
waitChan := make(chan struct{}, 1)
|
||||
m.openQueue = append(m.openQueue, waitChan)
|
||||
queuePos := m.highestInQueue
|
||||
m.highestInQueue++
|
||||
if len(m.openQueue) == 0 {
|
||||
m.lowestInQueue = queuePos
|
||||
}
|
||||
m.openQueue[queuePos] = waitChan
|
||||
m.maybeSendBlockedFrame()
|
||||
|
||||
for {
|
||||
m.mutex.Unlock()
|
||||
<-waitChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.mutex.Lock()
|
||||
delete(m.openQueue, queuePos)
|
||||
return nil, ctx.Err()
|
||||
case <-waitChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
|
||||
if m.closeErr != nil {
|
||||
@@ -84,7 +103,7 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) {
|
||||
continue
|
||||
}
|
||||
str := m.openStream()
|
||||
m.openQueue = m.openQueue[1:]
|
||||
delete(m.openQueue, queuePos)
|
||||
m.unblockOpenSync()
|
||||
return str, nil
|
||||
}
|
||||
@@ -157,9 +176,15 @@ func (m *outgoingItemsMap) unblockOpenSync() {
|
||||
if len(m.openQueue) == 0 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.openQueue[0] <- struct{}{}:
|
||||
default:
|
||||
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
|
||||
c, ok := m.openQueue[qp]
|
||||
if !ok { // entry was deleted because the context was canceled
|
||||
continue
|
||||
}
|
||||
close(c)
|
||||
m.openQueue[qp] = nil
|
||||
m.lowestInQueue = qp + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +195,9 @@ func (m *outgoingItemsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
for _, c := range m.openQueue {
|
||||
close(c)
|
||||
if c != nil {
|
||||
close(c)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@@ -106,12 +107,19 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
expectTooManyStreamsError(err)
|
||||
})
|
||||
|
||||
It("returns immediately when called with a canceled context", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.OpenStreamSync(ctx)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
})
|
||||
|
||||
It("blocks until a stream can be opened synchronously", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
str, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
close(done)
|
||||
@@ -122,12 +130,34 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("unblocks when the context is canceled", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.OpenStreamSync(ctx)
|
||||
Expect(err).To(MatchError("context canceled"))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// make sure that the next stream openend is stream 1
|
||||
m.SetMaxStream(1000)
|
||||
str, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
})
|
||||
|
||||
It("opens streams in the right order", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
str, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
close(done1)
|
||||
@@ -136,7 +166,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
str, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2)))
|
||||
close(done2)
|
||||
@@ -155,20 +185,20 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.OpenStreamSync()
|
||||
_, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.OpenStreamSync()
|
||||
_, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
Consistently(done).ShouldNot(Receive())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.OpenStreamSync()
|
||||
_, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).To(MatchError("test done"))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
@@ -188,7 +218,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
openedSync := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := m.OpenStreamSync()
|
||||
str, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1)))
|
||||
close(openedSync)
|
||||
@@ -229,7 +259,7 @@ var _ = Describe("Streams Map (outgoing)", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := m.OpenStreamSync()
|
||||
_, err := m.OpenStreamSync(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
@@ -14,10 +15,12 @@ import (
|
||||
type outgoingUniStreamsMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
openQueue []chan struct{}
|
||||
|
||||
streams map[protocol.StreamNum]sendStreamI
|
||||
|
||||
openQueue map[uint64]chan struct{}
|
||||
lowestInQueue uint64
|
||||
highestInQueue uint64
|
||||
|
||||
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
|
||||
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
|
||||
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
|
||||
@@ -34,6 +37,7 @@ func newOutgoingUniStreamsMap(
|
||||
) *outgoingUniStreamsMap {
|
||||
return &outgoingUniStreamsMap{
|
||||
streams: make(map[protocol.StreamNum]sendStreamI),
|
||||
openQueue: make(map[uint64]chan struct{}),
|
||||
maxStream: protocol.InvalidStreamNum,
|
||||
nextStream: 1,
|
||||
newStream: newStream,
|
||||
@@ -57,7 +61,7 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
|
||||
func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -65,17 +69,32 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
|
||||
return nil, m.closeErr
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
|
||||
return m.openStream(), nil
|
||||
}
|
||||
|
||||
waitChan := make(chan struct{}, 1)
|
||||
m.openQueue = append(m.openQueue, waitChan)
|
||||
queuePos := m.highestInQueue
|
||||
m.highestInQueue++
|
||||
if len(m.openQueue) == 0 {
|
||||
m.lowestInQueue = queuePos
|
||||
}
|
||||
m.openQueue[queuePos] = waitChan
|
||||
m.maybeSendBlockedFrame()
|
||||
|
||||
for {
|
||||
m.mutex.Unlock()
|
||||
<-waitChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.mutex.Lock()
|
||||
delete(m.openQueue, queuePos)
|
||||
return nil, ctx.Err()
|
||||
case <-waitChan:
|
||||
}
|
||||
m.mutex.Lock()
|
||||
|
||||
if m.closeErr != nil {
|
||||
@@ -86,7 +105,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) {
|
||||
continue
|
||||
}
|
||||
str := m.openStream()
|
||||
m.openQueue = m.openQueue[1:]
|
||||
delete(m.openQueue, queuePos)
|
||||
m.unblockOpenSync()
|
||||
return str, nil
|
||||
}
|
||||
@@ -159,9 +178,15 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
|
||||
if len(m.openQueue) == 0 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.openQueue[0] <- struct{}{}:
|
||||
default:
|
||||
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
|
||||
c, ok := m.openQueue[qp]
|
||||
if !ok { // entry was deleted because the context was canceled
|
||||
continue
|
||||
}
|
||||
close(c)
|
||||
m.openQueue[qp] = nil
|
||||
m.lowestInQueue = qp + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +197,9 @@ func (m *outgoingUniStreamsMap) CloseWithError(err error) {
|
||||
str.closeForShutdown(err)
|
||||
}
|
||||
for _, c := range m.openQueue {
|
||||
close(c)
|
||||
if c != nil {
|
||||
close(c)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -121,7 +122,7 @@ var _ = Describe("Streams Map", func() {
|
||||
It("accepts bidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeAssignableToTypeOf(&stream{}))
|
||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
|
||||
@@ -130,7 +131,7 @@ var _ = Describe("Streams Map", func() {
|
||||
It("accepts unidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := m.AcceptUniStream()
|
||||
str, err := m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
|
||||
Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
|
||||
@@ -170,7 +171,7 @@ var _ = Describe("Streams Map", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(id)).To(Succeed())
|
||||
str, err := m.AcceptStream()
|
||||
str, err := m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(str.StreamID()).To(Equal(id))
|
||||
@@ -203,7 +204,7 @@ var _ = Describe("Streams Map", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(id)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.DeleteStream(id)).To(Succeed())
|
||||
str, err := m.AcceptUniStream()
|
||||
str, err := m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
Expect(str.StreamID()).To(Equal(id))
|
||||
@@ -393,7 +394,7 @@ var _ = Describe("Streams Map", func() {
|
||||
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.AcceptStream()
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeBidi,
|
||||
@@ -405,7 +406,7 @@ var _ = Describe("Streams Map", func() {
|
||||
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
|
||||
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.AcceptUniStream()
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeUni,
|
||||
@@ -424,10 +425,10 @@ var _ = Describe("Streams Map", func() {
|
||||
_, err = m.OpenUniStream()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
_, err = m.AcceptStream()
|
||||
_, err = m.AcceptStream(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
_, err = m.AcceptUniStream()
|
||||
_, err = m.AcceptUniStream(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(testErr.Error()))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user