fix race condition when handling the header stream fails in h2quic client

This commit is contained in:
Marten Seemann
2017-06-06 16:56:39 +02:00
parent 500d9889f5
commit 05d7bc91ef
2 changed files with 33 additions and 15 deletions

View File

@@ -40,6 +40,7 @@ type client struct {
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
@@ -58,7 +59,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true,
},
opts: opts,
opts: opts,
headerErrored: make(chan struct{}),
}
}
@@ -109,7 +111,7 @@ func (c *client) handleHeaderStream() {
}
c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
@@ -120,16 +122,12 @@ func (c *client) handleHeaderStream() {
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
headerChan <- rsp
responseChan <- rsp
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock()
for _, responseChan := range c.responses {
close(responseChan)
}
c.mutex.Unlock()
close(c.headerErrored)
}
// Roundtrip executes a request and returns a response
@@ -197,15 +195,15 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
case <-c.headerErrored:
// an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
}

View File

@@ -165,6 +165,7 @@ var _ = Describe("Client", func() {
StatusCode: 418,
}
Expect(client.responses[5]).ToNot(BeClosed())
Expect(client.headerErrored).ToNot(BeClosed())
client.responses[5] <- rsp
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(doErr).ToNot(HaveOccurred())
@@ -177,6 +178,7 @@ var _ = Describe("Client", func() {
})
It("closes the quic client when encountering an error on the header stream", func(done Done) {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var doReturned bool
go func() {
defer GinkgoRecover()
@@ -187,13 +189,31 @@ var _ = Describe("Client", func() {
doReturned = true
}()
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
close(done)
}, 2)
It("returns subsequent request if there was an error on the header stream before", func(done Done) {
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
var firstReqReturned bool
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr))
firstReqReturned = true
}()
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
// now that the first request failed due to an error on the header stream, try another request
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(expectedErr))
close(done)
})
It("blocks if no stream is available", func() {
session.streamsToOpen = []quic.Stream{headerStream}
session.blockOpenStreamSync = true
@@ -455,7 +475,7 @@ var _ = Describe("Client", func() {
handlerReturned = true
}()
Eventually(client.responses[23]).Should(BeClosed())
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
})
@@ -473,7 +493,7 @@ var _ = Describe("Client", func() {
handlerReturned = true
}()
Eventually(client.responses[23]).Should(BeClosed())
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
})