forked from quic-go/quic-go
handle request cancelations while waiting for handshake completion
This commit is contained in:
@@ -101,7 +101,6 @@ func (c *client) dial() error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
<-c.session.HandshakeComplete().Done()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,6 +150,13 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
return nil, c.handshakeErr
|
return nil, c.handshakeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for the handshake to complete
|
||||||
|
select {
|
||||||
|
case <-c.session.HandshakeComplete().Done():
|
||||||
|
case <-req.Context().Done():
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
}
|
||||||
|
|
||||||
str, err := c.session.OpenStreamSync(req.Context())
|
str, err := c.session.OpenStreamSync(req.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -310,6 +310,21 @@ var _ = Describe("Client", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Context("request cancellations", func() {
|
Context("request cancellations", func() {
|
||||||
|
It("cancels a request while waiting for the handshake to complete", func() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req := request.WithContext(ctx)
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||||
|
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
Consistently(errChan).ShouldNot(Receive())
|
||||||
|
cancel()
|
||||||
|
Eventually(errChan).Should(Receive(MatchError("context canceled")))
|
||||||
|
})
|
||||||
|
|
||||||
It("cancels a request while the request is still in flight", func() {
|
It("cancels a request while the request is still in flight", func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := request.WithContext(ctx)
|
||||||
|
|||||||
@@ -135,8 +135,8 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
It("reuses existing clients", func() {
|
It("reuses existing clients", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
testErr := errors.New("test err")
|
testErr := errors.New("test err")
|
||||||
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
|
||||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||||
|
session.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
|
||||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
||||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user