From 05d7bc91ef8fc8396b526410f6f227497e03b60a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 6 Jun 2017 16:56:39 +0200 Subject: [PATCH] fix race condition when handling the header stream fails in h2quic client --- h2quic/client.go | 22 ++++++++++------------ h2quic/client_test.go | 26 +++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index fb569eae..d0bb96cc 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -40,6 +40,7 @@ type client struct { session quic.Session headerStream quic.Stream headerErr *qerr.QuicError + headerErrored chan struct{} // this channel is closed if an error occurs on the header stream requestWriter *requestWriter responses map[protocol.StreamID]chan *http.Response @@ -58,7 +59,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * TLSConfig: tlsConfig, RequestConnectionIDTruncation: true, }, - opts: opts, + opts: opts, + headerErrored: make(chan struct{}), } } @@ -109,7 +111,7 @@ func (c *client) handleHeaderStream() { } c.mutex.RLock() - headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] + responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] c.mutex.RUnlock() if !ok { c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) @@ -120,16 +122,12 @@ func (c *client) handleHeaderStream() { if err != nil { c.headerErr = qerr.Error(qerr.InternalError, err.Error()) } - headerChan <- rsp + responseChan <- rsp } // stop all running request utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) - c.mutex.Lock() - for _, responseChan := range c.responses { - close(responseChan) - } - c.mutex.Unlock() + close(c.headerErrored) } // Roundtrip executes a request and returns a response @@ -197,15 +195,15 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { c.mutex.Lock() delete(c.responses, dataStream.StreamID()) c.mutex.Unlock() - if res == nil { // an error occured on the header stream - c.Close(c.headerErr) - return nil, c.headerErr - } case err := <-resc: bodySent = true if err != nil { return nil, err } + case <-c.headerErrored: + // an error occured on the header stream + c.Close(c.headerErr) + return nil, c.headerErr } } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index a9dc3861..6f66425d 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -165,6 +165,7 @@ var _ = Describe("Client", func() { StatusCode: 418, } Expect(client.responses[5]).ToNot(BeClosed()) + Expect(client.headerErrored).ToNot(BeClosed()) client.responses[5] <- rsp Eventually(func() bool { return doReturned }).Should(BeTrue()) Expect(doErr).ToNot(HaveOccurred()) @@ -177,6 +178,7 @@ var _ = Describe("Client", func() { }) It("closes the quic client when encountering an error on the header stream", func(done Done) { + headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) var doReturned bool go func() { defer GinkgoRecover() @@ -187,13 +189,31 @@ var _ = Describe("Client", func() { doReturned = true }() - headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) Eventually(func() bool { return doReturned }).Should(BeTrue()) Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr)) close(done) }, 2) + It("returns subsequent request if there was an error on the header stream before", func(done Done) { + expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") + session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)} + headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) + var firstReqReturned bool + go func() { + defer GinkgoRecover() + _, err := client.RoundTrip(request) + Expect(err).To(MatchError(expectedErr)) + firstReqReturned = true + }() + + Eventually(func() bool { return firstReqReturned }).Should(BeTrue()) + // now that the first request failed due to an error on the header stream, try another request + _, err := client.RoundTrip(request) + Expect(err).To(MatchError(expectedErr)) + close(done) + }) + It("blocks if no stream is available", func() { session.streamsToOpen = []quic.Stream{headerStream} session.blockOpenStreamSync = true @@ -455,7 +475,7 @@ var _ = Describe("Client", func() { handlerReturned = true }() - Eventually(client.responses[23]).Should(BeClosed()) + Eventually(client.headerErrored).Should(BeClosed()) Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame"))) Eventually(func() bool { return handlerReturned }).Should(BeTrue()) }) @@ -473,7 +493,7 @@ var _ = Describe("Client", func() { handlerReturned = true }() - Eventually(client.responses[23]).Should(BeClosed()) + Eventually(client.headerErrored).Should(BeClosed()) Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields"))) Eventually(func() bool { return handlerReturned }).Should(BeTrue()) })