From 18422ad1c44726209a7e854a473d61b1c78bfdfc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Apr 2024 11:42:33 +0200 Subject: [PATCH] http3: remove RoundTripOpt.CheckSettings (#4416) The settings can be obtained from the SingleDestinationRoundTripper. --- http3/client.go | 23 ++--- http3/client_test.go | 121 ++++++++++++++------------ http3/mock_singleroundtripper_test.go | 24 ++--- http3/roundtrip.go | 8 +- http3/roundtrip_test.go | 24 ++--- integrationtests/self/http_test.go | 31 ++++--- 6 files changed, 116 insertions(+), 115 deletions(-) diff --git a/http3/client.go b/http3/client.go index 5341e6d3..a32844e3 100644 --- a/http3/client.go +++ b/http3/client.go @@ -69,6 +69,8 @@ type SingleDestinationRoundTripper struct { decoder *qpack.Decoder } +var _ http.RoundTripper = &SingleDestinationRoundTripper{} + func (c *SingleDestinationRoundTripper) Start() Connection { c.initOnce.Do(func() { c.init() }) return c.hconn @@ -141,11 +143,11 @@ func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 { return uint64(c.MaxResponseHeaderBytes) } -// RoundTripOpt executes a request and returns a response -func (c *SingleDestinationRoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { +// RoundTrip executes a request and returns a response +func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { c.initOnce.Do(func() { c.init() }) - rsp, err := c.roundTripOpt(req, opt) + rsp, err := c.roundTrip(req) if err != nil && req.Context().Err() != nil { // if the context was canceled, return the context cancellation error err = req.Context().Err() @@ -153,7 +155,7 @@ func (c *SingleDestinationRoundTripper) RoundTripOpt(req *http.Request, opt Roun return rsp, err } -func (c *SingleDestinationRoundTripper) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { +func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) { // Immediately send out this request, if this is a 0-RTT request. switch req.Method { case MethodGet0RTT: @@ -178,19 +180,6 @@ func (c *SingleDestinationRoundTripper) roundTripOpt(req *http.Request, opt Roun } } - if opt.CheckSettings != nil { - connCtx := c.Connection.Context() - // wait for the server's SETTINGS frame to arrive - select { - case <-c.hconn.ReceivedSettings(): - case <-connCtx.Done(): - return nil, context.Cause(connCtx) - } - if err := opt.CheckSettings(*c.hconn.Settings()); err != nil { - return nil, err - } - } - str, err := c.Connection.OpenStreamSync(req.Context()) if err != nil { return nil, err diff --git a/http3/client_test.go b/http3/client_test.go index 062ce6a7..2c7eb4bf 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -95,7 +95,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := rt.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTrip(request) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -123,7 +123,7 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := rt.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTrip(request) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) @@ -150,7 +150,7 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := rt.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTrip(request) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) @@ -179,7 +179,7 @@ var _ = Describe("Client", func() { ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := rt.RoundTripOpt(request, RoundTripOpt{}) + _, err := rt.RoundTrip(request) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -242,7 +242,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -271,7 +271,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -301,7 +301,7 @@ var _ = Describe("Client", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - _, err := rt.RoundTripOpt(req, RoundTripOpt{}) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -309,58 +309,71 @@ var _ = Describe("Client", func() { }) Context("SETTINGS handling", func() { - var ( - req *http.Request - conn *mockquic.MockEarlyConnection - rt *SingleDestinationRoundTripper - settingsFrameWritten chan struct{} - ) - testDone := make(chan struct{}, 1) + var settingsFrameWritten chan struct{} BeforeEach(func() { settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) + var buf bytes.Buffer controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() + buf.Write(b) close(settingsFrameWritten) return len(b), nil }) - conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn := mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-settingsFrameWritten + return nil, errors.New("test done") + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-settingsFrameWritten + return nil, errors.New("test done") + }).AnyTimes() conn.EXPECT().HandshakeComplete().Return(handshakeChan) - rt = &SingleDestinationRoundTripper{Connection: conn} - var err error - req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + rt := &SingleDestinationRoundTripper{ + Connection: conn, + EnableDatagrams: true, + } + req, err := http.NewRequest(http.MethodGet, "https://quic-go.net", nil) Expect(err).ToNot(HaveOccurred()) + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError("test done")) + t, err := quicvarint.Read(&buf) + Expect(err).ToNot(HaveOccurred()) + Expect(t).To(BeEquivalentTo(streamTypeControlStream)) + settings, err := parseSettingsFrame(&buf, uint64(buf.Len())) + Expect(err).ToNot(HaveOccurred()) + Expect(settings.Datagram).To(BeTrue()) }) - AfterEach(func() { - testDone <- struct{}{} - Eventually(settingsFrameWritten).Should(BeClosed()) - }) - - It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() { + It("receives SETTINGS", func() { + done := make(chan struct{}) + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { + <-done + return nil, errors.New("test done") + }).MaxTimes(1) b := quicvarint.Append(nil, streamTypeControlStream) - b = (&settingsFrame{}).Append(b) + b = (&settingsFrame{Datagram: true}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - // Don't EXPECT any call to OpenStreamSync. - // When the SETTINGS are rejected, we don't even open the request stream. + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil - }) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-testDone + <-done return nil, errors.New("test done") }) - conn.EXPECT().Context().Return(context.Background()) - _, err := rt.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { - return errors.New("wrong settings") - }}) - rt.Connection = conn - Expect(err).To(MatchError("wrong settings")) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + + rt := &SingleDestinationRoundTripper{Connection: conn} + hconn := rt.Start() + Eventually(hconn.ReceivedSettings()).Should(BeClosed()) + settings := hconn.Settings() + Expect(settings.EnableDatagram).To(BeTrue()) + // test shutdown + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + close(done) }) }) @@ -428,7 +441,7 @@ var _ = Describe("Client", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(handshakeChan) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(testErr)) }) @@ -447,7 +460,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(testErr)) Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", serialized)) // make sure the request wasn't modified @@ -467,7 +480,7 @@ var _ = Describe("Client", func() { str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.ProtoMajor).To(Equal(3)) @@ -506,7 +519,7 @@ var _ = Describe("Client", func() { <-done return 0, testErr }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(testErr)) hfs := decodeHeader(strBuf) Expect(hfs).To(HaveKeyWithValue(":method", "POST")) @@ -532,7 +545,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("done") }) - cl.RoundTripOpt(req, RoundTripOpt{}) + cl.RoundTrip(req) Expect(strBuf.String()).To(ContainSubstring("request")) Expect(strBuf.String()).ToNot(ContainSubstring("request body")) }) @@ -556,7 +569,7 @@ var _ = Describe("Client", func() { }) closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(testErr)) Eventually(closed).Should(BeClosed()) }) @@ -568,7 +581,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError("http3: expected first frame to be a HEADERS frame")) Eventually(closed).Should(BeClosed()) }) @@ -587,7 +600,7 @@ var _ = Describe("Client", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(HaveOccurred()) Eventually(closed).Should(BeClosed()) }) @@ -601,7 +614,7 @@ var _ = Describe("Client", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) }) @@ -627,7 +640,7 @@ var _ = Describe("Client", func() { errChan := make(chan error) go func() { - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) errChan <- err }() Consistently(errChan).ShouldNot(Receive()) @@ -658,7 +671,7 @@ var _ = Describe("Client", func() { <-canceled return 0, errors.New("test done") }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(context.Canceled)) Eventually(done).Should(BeClosed()) }) @@ -679,7 +692,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(done).Should(BeClosed()) @@ -703,7 +716,7 @@ var _ = Describe("Client", func() { ) testErr := errors.New("test done") str.EXPECT().Read(gomock.Any()).Return(0, testErr) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTrip(req) Expect(err).To(MatchError(testErr)) hfs := decodeHeader(buf) Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) @@ -725,7 +738,7 @@ var _ = Describe("Client", func() { ) testErr := errors.New("test done") str.EXPECT().Read(gomock.Any()).Return(0, testErr) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := client.RoundTrip(req) Expect(err).To(MatchError(testErr)) hfs := decodeHeader(buf) Expect(hfs).ToNot(HaveKey("accept-encoding")) @@ -747,7 +760,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) @@ -770,7 +783,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/mock_singleroundtripper_test.go b/http3/mock_singleroundtripper_test.go index 1549a1be..59dc3f53 100644 --- a/http3/mock_singleroundtripper_test.go +++ b/http3/mock_singleroundtripper_test.go @@ -79,41 +79,41 @@ func (c *MockSingleRoundTripperOpenRequestStreamCall) DoAndReturn(f func(context return c } -// RoundTripOpt mocks base method. -func (m *MockSingleRoundTripper) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) { +// RoundTrip mocks base method. +func (m *MockSingleRoundTripper) RoundTrip(arg0 *http.Request) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1) + ret := m.ctrl.Call(m, "RoundTrip", arg0) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } -// RoundTripOpt indicates an expected call of RoundTripOpt. -func (mr *MockSingleRoundTripperMockRecorder) RoundTripOpt(arg0, arg1 any) *MockSingleRoundTripperRoundTripOptCall { +// RoundTrip indicates an expected call of RoundTrip. +func (mr *MockSingleRoundTripperMockRecorder) RoundTrip(arg0 any) *MockSingleRoundTripperRoundTripCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockSingleRoundTripper)(nil).RoundTripOpt), arg0, arg1) - return &MockSingleRoundTripperRoundTripOptCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTrip", reflect.TypeOf((*MockSingleRoundTripper)(nil).RoundTrip), arg0) + return &MockSingleRoundTripperRoundTripCall{Call: call} } -// MockSingleRoundTripperRoundTripOptCall wrap *gomock.Call -type MockSingleRoundTripperRoundTripOptCall struct { +// MockSingleRoundTripperRoundTripCall wrap *gomock.Call +type MockSingleRoundTripperRoundTripCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockSingleRoundTripperRoundTripOptCall) Return(arg0 *http.Response, arg1 error) *MockSingleRoundTripperRoundTripOptCall { +func (c *MockSingleRoundTripperRoundTripCall) Return(arg0 *http.Response, arg1 error) *MockSingleRoundTripperRoundTripCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockSingleRoundTripperRoundTripOptCall) Do(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockSingleRoundTripperRoundTripOptCall { +func (c *MockSingleRoundTripperRoundTripCall) Do(f func(*http.Request) (*http.Response, error)) *MockSingleRoundTripperRoundTripCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSingleRoundTripperRoundTripOptCall) DoAndReturn(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockSingleRoundTripperRoundTripOptCall { +func (c *MockSingleRoundTripperRoundTripCall) DoAndReturn(f func(*http.Request) (*http.Response, error)) *MockSingleRoundTripperRoundTripCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/http3/roundtrip.go b/http3/roundtrip.go index f8b0e542..e2b9171c 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -33,15 +33,11 @@ type RoundTripOpt struct { // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool - // CheckSettings is run before the request is sent to the server. - // If not yet received, it blocks until the server's SETTINGS frame is received. - // If an error is returned, the request won't be sent to the server, and the error is returned. - CheckSettings func(Settings) error } type singleRoundTripper interface { - RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) OpenRequestStream(context.Context) (RequestStream, error) + RoundTrip(*http.Request) (*http.Response, error) } type roundTripperWithCount struct { @@ -163,7 +159,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. return nil, cl.dialErr } defer cl.useCount.Add(-1) - rsp, err := cl.rt.RoundTripOpt(req, opt) + rsp, err := cl.rt.RoundTrip(req) if err != nil { // non-nil errors on roundtrip are likely due to a problem with the connection // so we remove the client from the cache so that subsequent trips reconnect diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index c314f273..b21f9313 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -249,8 +249,8 @@ var _ = Describe("RoundTripper", func() { close(handshakeChan) conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) - cl.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(&http.Response{Request: req1}, nil) - cl.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) + cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil) + cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) var count int rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { count++ @@ -286,8 +286,8 @@ var _ = Describe("RoundTripper", func() { handshakeChan := make(chan struct{}) close(handshakeChan) conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) - cl1.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(nil, testErr) - cl2.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) + cl1.EXPECT().RoundTrip(req1).Return(nil, testErr) + cl2.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) _, err = rt.RoundTrip(req1) Expect(err).To(MatchError(testErr)) rsp, err := rt.RoundTrip(req2) @@ -317,8 +317,8 @@ var _ = Describe("RoundTripper", func() { handshakeChan := make(chan struct{}) close(handshakeChan) conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2) - cl1.EXPECT().RoundTripOpt(req1, gomock.Any()).Return(nil, testErr) - cl1.EXPECT().RoundTripOpt(req2, gomock.Any()).Return(&http.Response{Request: req2}, nil) + cl1.EXPECT().RoundTrip(req1).Return(nil, testErr) + cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) _, err = rt.RoundTrip(req1) Expect(err).To(MatchError(testErr)) rsp, err := rt.RoundTrip(req2) @@ -330,7 +330,7 @@ var _ = Describe("RoundTripper", func() { It("recreates a client when a request times out", func() { var reqCount int cl1 := NewMockSingleRoundTripper(mockCtrl) - cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + cl1.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) { reqCount++ if reqCount == 1 { // the first request is successful... Expect(req.URL).To(Equal(req1.URL)) @@ -341,7 +341,7 @@ var _ = Describe("RoundTripper", func() { return nil, &qerr.IdleTimeoutError{} }).Times(2) cl2 := NewMockSingleRoundTripper(mockCtrl) - cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + cl2.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) { return &http.Response{Request: req}, nil }) clientChan <- cl1 @@ -372,7 +372,7 @@ var _ = Describe("RoundTripper", func() { } rt.newClient = func(quic.EarlyConnection) singleRoundTripper { cl := NewMockSingleRoundTripper(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) + cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) return cl } _, err := rt.RoundTrip(req1) @@ -385,7 +385,7 @@ var _ = Describe("RoundTripper", func() { reqs := make(chan struct{}, 2) cl := NewMockSingleRoundTripper(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(req *http.Request) (*http.Response, error) { reqs <- struct{}{} <-wait return nil, &qerr.IdleTimeoutError{} @@ -503,7 +503,7 @@ var _ = Describe("RoundTripper", func() { }, newClient: func(quic.EarlyConnection) singleRoundTripper { cl := NewMockSingleRoundTripper(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(&http.Response{}, nil) + cl.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{}, nil) return cl }, } @@ -544,7 +544,7 @@ var _ = Describe("RoundTripper", func() { reqFinished := make(chan struct{}) rt.newClient = func(quic.EarlyConnection) singleRoundTripper { cl := NewMockSingleRoundTripper(mockCtrl) - cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) { + cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) { roundTripCalled <- struct{}{} <-r.Context().Done() return nil, nil diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 41bdc20f..44beaf95 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -587,16 +587,23 @@ var _ = Describe("HTTP tests", func() { }) It("checks the server's settings", func() { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) + tlsConf := tlsClientConfigWithoutServerName.Clone() + tlsConf.NextProtos = []string{http3.NextProtoH3} + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + tlsConf, + getQuicConfig(nil), + ) Expect(err).ToNot(HaveOccurred()) - testErr := errors.New("test error") - _, err = rt.RoundTripOpt(req, http3.RoundTripOpt{CheckSettings: func(settings http3.Settings) error { - Expect(settings.EnableExtendedConnect).To(BeTrue()) - Expect(settings.EnableDatagram).To(BeFalse()) - Expect(settings.Other).To(BeEmpty()) - return testErr - }}) - Expect(err).To(MatchError(err)) + defer conn.CloseWithError(0, "") + rt := http3.SingleDestinationRoundTripper{Connection: conn} + hconn := rt.Start() + Eventually(hconn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed()) + settings := hconn.Settings() + Expect(settings.EnableExtendedConnect).To(BeTrue()) + Expect(settings.EnableDatagram).To(BeFalse()) + Expect(settings.Other).To(BeEmpty()) }) It("receives the client's settings", func() { @@ -604,11 +611,7 @@ var _ = Describe("HTTP tests", func() { mux.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() conn := w.(http3.Hijacker).Connection() - select { - case <-conn.ReceivedSettings(): - case <-time.After(5 * time.Second): - Fail("didn't receive SETTINGS") - } + Eventually(conn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed()) settingsChan <- conn.Settings() w.WriteHeader(http.StatusOK) })