forked from quic-go/quic-go
fix race condition when handling the header stream fails in h2quic client
This commit is contained in:
@@ -40,6 +40,7 @@ type client struct {
|
|||||||
session quic.Session
|
session quic.Session
|
||||||
headerStream quic.Stream
|
headerStream quic.Stream
|
||||||
headerErr *qerr.QuicError
|
headerErr *qerr.QuicError
|
||||||
|
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
||||||
requestWriter *requestWriter
|
requestWriter *requestWriter
|
||||||
|
|
||||||
responses map[protocol.StreamID]chan *http.Response
|
responses map[protocol.StreamID]chan *http.Response
|
||||||
@@ -59,6 +60,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
|
|||||||
RequestConnectionIDTruncation: true,
|
RequestConnectionIDTruncation: true,
|
||||||
},
|
},
|
||||||
opts: opts,
|
opts: opts,
|
||||||
|
headerErrored: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +111,7 @@ func (c *client) handleHeaderStream() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||||
c.mutex.RUnlock()
|
c.mutex.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||||
@@ -120,16 +122,12 @@ func (c *client) handleHeaderStream() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
||||||
}
|
}
|
||||||
headerChan <- rsp
|
responseChan <- rsp
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop all running request
|
// stop all running request
|
||||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||||
c.mutex.Lock()
|
close(c.headerErrored)
|
||||||
for _, responseChan := range c.responses {
|
|
||||||
close(responseChan)
|
|
||||||
}
|
|
||||||
c.mutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roundtrip executes a request and returns a response
|
// Roundtrip executes a request and returns a response
|
||||||
@@ -197,15 +195,15 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
delete(c.responses, dataStream.StreamID())
|
delete(c.responses, dataStream.StreamID())
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
if res == nil { // an error occured on the header stream
|
|
||||||
c.Close(c.headerErr)
|
|
||||||
return nil, c.headerErr
|
|
||||||
}
|
|
||||||
case err := <-resc:
|
case err := <-resc:
|
||||||
bodySent = true
|
bodySent = true
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
case <-c.headerErrored:
|
||||||
|
// an error occured on the header stream
|
||||||
|
c.Close(c.headerErr)
|
||||||
|
return nil, c.headerErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ var _ = Describe("Client", func() {
|
|||||||
StatusCode: 418,
|
StatusCode: 418,
|
||||||
}
|
}
|
||||||
Expect(client.responses[5]).ToNot(BeClosed())
|
Expect(client.responses[5]).ToNot(BeClosed())
|
||||||
|
Expect(client.headerErrored).ToNot(BeClosed())
|
||||||
client.responses[5] <- rsp
|
client.responses[5] <- rsp
|
||||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||||
Expect(doErr).ToNot(HaveOccurred())
|
Expect(doErr).ToNot(HaveOccurred())
|
||||||
@@ -177,6 +178,7 @@ var _ = Describe("Client", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("closes the quic client when encountering an error on the header stream", func(done Done) {
|
It("closes the quic client when encountering an error on the header stream", func(done Done) {
|
||||||
|
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||||
var doReturned bool
|
var doReturned bool
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
@@ -187,13 +189,31 @@ var _ = Describe("Client", func() {
|
|||||||
doReturned = true
|
doReturned = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
||||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
||||||
close(done)
|
close(done)
|
||||||
}, 2)
|
}, 2)
|
||||||
|
|
||||||
|
It("returns subsequent request if there was an error on the header stream before", func(done Done) {
|
||||||
|
expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||||
|
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
||||||
|
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||||
|
var firstReqReturned bool
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError(expectedErr))
|
||||||
|
firstReqReturned = true
|
||||||
|
}()
|
||||||
|
|
||||||
|
Eventually(func() bool { return firstReqReturned }).Should(BeTrue())
|
||||||
|
// now that the first request failed due to an error on the header stream, try another request
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError(expectedErr))
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
|
||||||
It("blocks if no stream is available", func() {
|
It("blocks if no stream is available", func() {
|
||||||
session.streamsToOpen = []quic.Stream{headerStream}
|
session.streamsToOpen = []quic.Stream{headerStream}
|
||||||
session.blockOpenStreamSync = true
|
session.blockOpenStreamSync = true
|
||||||
@@ -455,7 +475,7 @@ var _ = Describe("Client", func() {
|
|||||||
handlerReturned = true
|
handlerReturned = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
Eventually(client.responses[23]).Should(BeClosed())
|
Eventually(client.headerErrored).Should(BeClosed())
|
||||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
||||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||||
})
|
})
|
||||||
@@ -473,7 +493,7 @@ var _ = Describe("Client", func() {
|
|||||||
handlerReturned = true
|
handlerReturned = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
Eventually(client.responses[23]).Should(BeClosed())
|
Eventually(client.headerErrored).Should(BeClosed())
|
||||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
|
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")))
|
||||||
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
Eventually(func() bool { return handlerReturned }).Should(BeTrue())
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user