diff --git a/http3/client.go b/http3/client.go index 56b766b92..4979a9427 100644 --- a/http3/client.go +++ b/http3/client.go @@ -2,7 +2,6 @@ package http3 import ( "bytes" - "context" "crypto/tls" "errors" "fmt" @@ -150,7 +149,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, c.handshakeErr } - str, err := c.session.OpenStreamSync(context.Background()) + str, err := c.session.OpenStreamSync(req.Context()) if err != nil { return nil, err } diff --git a/http3/client_test.go b/http3/client_test.go index 1ce69aee8..ddb3ddb28 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -301,7 +301,7 @@ var _ = Describe("Client", func() { It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -333,7 +333,7 @@ var _ = Describe("Client", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1)