diff --git a/h2quic/client_test.go b/h2quic/client_test.go index db037c108..83151ff21 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -31,12 +31,24 @@ var _ = Describe("Client", func() { origDialAddr = dialAddr ) + injectResponse := func(id protocol.StreamID, rsp *http.Response) { + EventuallyWithOffset(0, func() bool { + client.mutex.Lock() + defer client.mutex.Unlock() + _, ok := client.responses[id] + return ok + }).Should(BeTrue()) + rspChan := client.responses[5] + ExpectWithOffset(0, rspChan).ToNot(BeClosed()) + rspChan <- rsp + } + BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil) Expect(client.hostname).To(Equal(hostname)) - session = &mockSession{} + session = newMockSession() session.ctx, session.ctxCancel = context.WithCancel(context.Background()) client.session = session @@ -88,10 +100,11 @@ var _ = Describe("Client", func() { _, err := client.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) close(done) + // fmt.Println("done") }() Eventually(func() quic.Session { return client.session }).Should(Equal(session)) // make the go routine return - client.responses[5] <- &http.Response{} + injectResponse(5, &http.Response{}) Eventually(done).Should(BeClosed()) }) @@ -126,9 +139,8 @@ var _ = Describe("Client", func() { Expect(qCfg).To(Equal(client.config)) Expect(tlsCfg).To(Equal(client.tlsConf)) // make the go routine return - client.responses[5] <- &http.Response{} + injectResponse(5, &http.Response{}) Eventually(done).Should(BeClosed()) - }) It("errors if it can't open a stream", func() { @@ -150,13 +162,15 @@ var _ = Describe("Client", func() { request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) - var doErr error + done := make(chan struct{}) go func() { - _, doErr = client.RoundTrip(request) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError(testErr)) + close(done) }() _, err = client.RoundTrip(request) Expect(err).To(MatchError(testErr)) - Eventually(func() error { return doErr }).Should(MatchError(testErr)) + Eventually(done).Should(BeClosed()) }) Context("Doing requests", func() { @@ -194,32 +208,27 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) - It("does a request", func(done Done) { - var doRsp *http.Response - var doErr error - var doReturned bool - go func() { - doRsp, doErr = client.RoundTrip(request) - doReturned = true - }() - - Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) - Eventually(func() map[protocol.StreamID]chan *http.Response { return client.responses }).Should(HaveKey(protocol.StreamID(5))) - rsp := &http.Response{ + It("does a request", func() { + teapot := &http.Response{ Status: "418 I'm a teapot", StatusCode: 418, } - Expect(client.responses[5]).ToNot(BeClosed()) - Expect(client.headerErrored).ToNot(BeClosed()) - client.responses[5] <- rsp - Eventually(func() bool { return doReturned }).Should(BeTrue()) - Expect(doErr).ToNot(HaveOccurred()) - Expect(doRsp).To(Equal(rsp)) - Expect(doRsp.Body).To(Equal(dataStream)) - Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) - Expect(doRsp.Request).To(Equal(request)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp).To(Equal(teapot)) + Expect(rsp.Body).To(Equal(dataStream)) + Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) + Expect(rsp.Request).To(Equal(request)) + close(done) + }() - close(done) + Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) + injectResponse(5, teapot) + Expect(client.headerErrored).ToNot(BeClosed()) + Eventually(done).Should(BeClosed()) }) It("closes the quic client when encountering an error on the header stream", func() { @@ -250,18 +259,21 @@ var _ = Describe("Client", func() { }) It("blocks if no stream is available", func() { - session.streamsToOpen = []quic.Stream{headerStream} + session.streamsToOpen = []quic.Stream{headerStream, dataStream} session.blockOpenStreamSync = true - var doReturned bool + done := make(chan struct{}) go func() { defer GinkgoRecover() _, err := client.RoundTrip(request) Expect(err).ToNot(HaveOccurred()) - doReturned = true + close(done) }() - go client.handleHeaderStream() - Consistently(func() bool { return doReturned }).Should(BeFalse()) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + client.Close() + injectResponse(5, &http.Response{}) + Eventually(done).Should(BeClosed()) }) Context("validating the address", func() { @@ -279,39 +291,56 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("quic http2: unsupported scheme")) }) - It("adds the port for request URLs without one", func(done Done) { - var err error + It("adds the port for request URLs without one", func() { client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - var doErr error - var doReturned bool + done := make(chan struct{}) // the client.RoundTrip will block, because the encryption level is still set to Unencrypted go func() { - _, doErr = client.RoundTrip(req) - doReturned = true + defer GinkgoRecover() + _, err := client.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + close(done) }() - Consistently(doReturned).Should(BeFalse()) - Expect(doErr).ToNot(HaveOccurred()) - close(done) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + injectResponse(5, &http.Response{}) + Eventually(done).Should(BeClosed()) }) }) It("sets the EndStream header for requests without a body", func() { - go func() { client.RoundTrip(request) }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RoundTrip(request) + close(done) + }() Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil()) mhf := getRequest(headerStream.dataWritten.Bytes()) Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue()) + // make the go routine return + injectResponse(5, &http.Response{}) + Eventually(done).Should(BeClosed()) }) It("sets the EndStream header to false for requests with a body", func() { request.Body = &mockBody{} - go func() { client.RoundTrip(request) }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RoundTrip(request) + close(done) + }() Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil()) mhf := getRequest(headerStream.dataWritten.Bytes()) Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse()) + // make the go routine return + injectResponse(5, &http.Response{}) + Eventually(done).Should(BeClosed()) }) Context("requests containing a Body", func() { @@ -333,38 +362,33 @@ var _ = Describe("Client", func() { }) It("sends a request", func() { - var doRsp *http.Response - var doErr error - var doReturned bool + rspChan := make(chan *http.Response) go func() { defer GinkgoRecover() - doRsp, doErr = client.RoundTrip(request) - Expect(doErr).ToNot(HaveOccurred()) - doReturned = true + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + rspChan <- rsp }() - Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) - client.responses[5] <- response - Eventually(func() bool { return doReturned }).Should(BeTrue()) + injectResponse(5, response) + Eventually(rspChan).Should(Receive(Equal(response))) Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody)) Expect(dataStream.closed).To(BeTrue()) Expect(request.Body.(*mockBody).closed).To(BeTrue()) - Expect(doRsp).To(Equal(response)) }) It("returns the error that occurred when reading the body", func() { testErr := errors.New("testErr") request.Body.(*mockBody).readErr = testErr - var doRsp *http.Response - var doErr error - var doReturned bool + done := make(chan struct{}) go func() { - doRsp, doErr = client.RoundTrip(request) - doReturned = true + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).To(MatchError(testErr)) + Expect(rsp).To(BeNil()) + close(done) }() - Eventually(func() bool { return doReturned }).Should(BeTrue()) - Expect(doErr).To(MatchError(testErr)) - Expect(doRsp).To(BeNil()) + Eventually(done).Should(BeClosed()) Expect(request.Body.(*mockBody).closed).To(BeTrue()) }) @@ -372,16 +396,15 @@ var _ = Describe("Client", func() { testErr := errors.New("testErr") request.Body.(*mockBody).closeErr = testErr - var doRsp *http.Response - var doErr error - var doReturned bool + done := make(chan struct{}) go func() { - doRsp, doErr = client.RoundTrip(request) - doReturned = true + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).To(MatchError(testErr)) + Expect(rsp).To(BeNil()) + close(done) }() - Eventually(func() bool { return doReturned }).Should(BeTrue()) - Expect(doErr).To(MatchError(testErr)) - Expect(doRsp).To(BeNil()) + Eventually(done).Should(BeClosed()) Expect(request.Body.(*mockBody).closed).To(BeTrue()) }) }) @@ -402,77 +425,91 @@ var _ = Describe("Client", func() { } }) - It("adds the gzip header to requests", func(done Done) { - var doRsp *http.Response - var doErr error - go func() { doRsp, doErr = client.RoundTrip(request) }() + It("adds the gzip header to requests", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp).ToNot(BeNil()) + Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) + Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) + Expect(rsp.Header.Get("Content-Length")).To(BeEmpty()) + data := make([]byte, 6) + _, err = io.ReadFull(rsp.Body, data) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + close(done) + }() - Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) dataStream.dataToRead.Write(gzippedData) response.Header.Add("Content-Encoding", "gzip") - client.responses[5] <- response - Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) - Expect(doErr).ToNot(HaveOccurred()) + injectResponse(5, response) headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) - Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) - Expect(doRsp.Header.Get("Content-Encoding")).To(BeEmpty()) - Expect(doRsp.Header.Get("Content-Length")).To(BeEmpty()) close(dataStream.unblockRead) - data := make([]byte, 6) - _, err := io.ReadFull(doRsp.Body, data) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - close(done) - }, 2) + Eventually(done).Should(BeClosed()) + }) It("doesn't add gzip if the header disable it", func() { client.opts.DisableCompression = true - var doErr error - go func() { _, doErr = client.RoundTrip(request) }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() - Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) - Expect(doErr).ToNot(HaveOccurred()) Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) Expect(headers).ToNot(HaveKey("accept-encoding")) + // make the go routine return + injectResponse(5, &http.Response{}) + Eventually(done).Should(BeClosed()) }) It("only decompresses the response if the response contains the right content-encoding header", func() { - var doRsp *http.Response - var doErr error - go func() { doRsp, doErr = client.RoundTrip(request) }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp).ToNot(BeNil()) + data := make([]byte, 11) + rsp.Body.Read(data) + Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1)) + Expect(data).To(Equal([]byte("not gzipped"))) + close(done) + }() - Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) dataStream.dataToRead.Write([]byte("not gzipped")) - client.responses[5] <- response - Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) - Expect(doErr).ToNot(HaveOccurred()) + injectResponse(5, response) headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) - data := make([]byte, 11) - doRsp.Body.Read(data) - Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) - Expect(data).To(Equal([]byte("not gzipped"))) + Eventually(done).Should(BeClosed()) }) It("doesn't add the gzip header for requests that have the accept-enconding set", func() { request.Header.Add("accept-encoding", "gzip") - var doRsp *http.Response - var doErr error - go func() { doRsp, doErr = client.RoundTrip(request) }() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + data := make([]byte, 12) + _, err = rsp.Body.Read(data) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1)) + Expect(data).To(Equal([]byte("gzipped data"))) + close(done) + }() - Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) dataStream.dataToRead.Write([]byte("gzipped data")) - client.responses[5] <- response - Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) - Expect(doErr).ToNot(HaveOccurred()) + injectResponse(5, response) headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) - data := make([]byte, 12) - doRsp.Body.Read(data) - Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) - Expect(data).To(Equal([]byte("gzipped data"))) + Eventually(done).Should(BeClosed()) }) }) diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 55ffd33d4..8427f1bf4 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -32,11 +32,16 @@ type mockSession struct { streamToAccept quic.Stream streamsToOpen []quic.Stream blockOpenStreamSync bool + blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return streamOpenErr error ctx context.Context ctxCancel context.CancelFunc } +func newMockSession() *mockSession { + return &mockSession{blockOpenStreamChan: make(chan struct{})} +} + func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) { return s.dataStream, nil } @@ -51,14 +56,17 @@ func (s *mockSession) OpenStream() (quic.Stream, error) { } func (s *mockSession) OpenStreamSync() (quic.Stream, error) { if s.blockOpenStreamSync { - time.Sleep(time.Hour) + <-s.blockOpenStreamChan } return s.OpenStream() } func (s *mockSession) Close(e error) error { - s.closed = true s.closedWithError = e s.ctxCancel() + if !s.closed { + close(s.blockOpenStreamChan) + } + s.closed = true return nil } func (s *mockSession) LocalAddr() net.Addr { @@ -88,7 +96,8 @@ var _ = Describe("H2 server", func() { } dataStream = newMockStream(0) close(dataStream.unblockRead) - session = &mockSession{dataStream: dataStream} + session = newMockSession() + session.dataStream = dataStream session.ctx, session.ctxCancel = context.WithCancel(context.Background()) origQuicListenAddr = quicListenAddr })