Make RoundTripper.RoundTrip(...) return if client timeout expires

Currently, the implementation of h2quic.RoundTripper.RoundTrip(req
*http.Request) ignores the context of req. As a result, if the
RoundTripper is used as transport of an http.Client with a timeout value
set, that is ignored.

For example, in the following snippet, client.Do(req) does not promptly
return if the task takes more than client.Timeout to complete.

    client := http.Client{
        Timeout: 50 * time.Millisecond,
        Transport = &h2quic.RoundTripper{}
    }
    req, err := http.NewRequest("GET", "https://www.example.com", nil)
    response, err := client.Do(req)

This commit updates h2quic.client.RoundTrip(req *http.Request) to return
an error if the request is cancelled.
This commit is contained in:
Lorenzo Saino
2018-02-07 14:37:49 +00:00
parent e26c1f09de
commit 10acc677aa
3 changed files with 78 additions and 7 deletions

View File

@@ -202,6 +202,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
@@ -214,6 +215,14 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.CloseWithError(c.headerErr)

View File

@@ -231,6 +231,67 @@ var _ = Describe("Client", func() {
Eventually(done).Should(BeClosed())
})
It("errors if a request without a body is canceled", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled after the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
time.Sleep(10 * time.Millisecond)
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled before the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
cancel()
time.Sleep(10 * time.Millisecond)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
done := make(chan struct{})

View File

@@ -18,12 +18,13 @@ import (
)
type mockStream struct {
id protocol.StreamID
dataToRead bytes.Buffer
dataWritten bytes.Buffer
reset bool
closed bool
remoteClosed bool
id protocol.StreamID
dataToRead bytes.Buffer
dataWritten bytes.Buffer
reset bool
canceledWrite bool
closed bool
remoteClosed bool
unblockRead chan struct{}
ctx context.Context
@@ -43,7 +44,7 @@ func newMockStream(id protocol.StreamID) *mockStream {
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { panic("not implemented") }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx }