forked from quic-go/quic-go
fix race condition when handling the header stream fails in h2quic client
This commit is contained in:
@@ -40,6 +40,7 @@ type client struct {
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
headerErr *qerr.QuicError
|
||||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
||||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
@@ -58,7 +59,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
|
||||
TLSConfig: tlsConfig,
|
||||
RequestConnectionIDTruncation: true,
|
||||
},
|
||||
opts: opts,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +111,7 @@ func (c *client) handleHeaderStream() {
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||
@@ -120,16 +122,12 @@ func (c *client) handleHeaderStream() {
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
||||
}
|
||||
headerChan <- rsp
|
||||
responseChan <- rsp
|
||||
}
|
||||
|
||||
// stop all running request
|
||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||
c.mutex.Lock()
|
||||
for _, responseChan := range c.responses {
|
||||
close(responseChan)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
@@ -197,15 +195,15 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
if res == nil { // an error occured on the header stream
|
||||
c.Close(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
case err := <-resc:
|
||||
bodySent = true
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-c.headerErrored:
|
||||
// an error occured on the header stream
|
||||
c.Close(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,6 +165,7 @@ var _ = Describe("Client", func() {
|
||||
StatusCode: 418,
|
||||
}
|
||||
Expect(client.responses[5]).ToNot(BeClosed())
|
||||
Expect(client.headerErrored).ToNot(BeClosed())
|
||||
client.responses[5] <- rsp
|
||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
@@ -177,6 +178,7 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("closes the quic client when encountering an error on the header stream", func(done Done) {
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
var doReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -187,13 +189,31 @@ var _ = Describe("Client", func() {
|
||||
doReturned = true
|
||||
}()
|
||||
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
||||
close(done)
|
||||
}, 2)
|
||||
|
||||
It("returns subsequent request if there was an error on the header stream before", func(done Done) {
|
||||
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
var firstReqReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(expectedErr))
|
||||
firstReqReturned = true
|
||||
}()
|
||||
|
||||
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
|
||||
// now that the first request failed due to an error on the header stream, try another request
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(expectedErr))
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("blocks if no stream is available", func() {
|
||||
session.streamsToOpen = []quic.Stream{headerStream}
|
||||
session.blockOpenStreamSync = true
|
||||
@@ -455,7 +475,7 @@ var _ = Describe("Client", func() {
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
Eventually(client.responses[23]).Should(BeClosed())
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
})
|
||||
@@ -473,7 +493,7 @@ var _ = Describe("Client", func() {
|
||||
handlerReturned = true
|
||||
}()
|
||||
|
||||
Eventually(client.responses[23]).Should(BeClosed())
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
|
||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user