diff --git a/http3/client.go b/http3/client.go index c990b353..35dd82e6 100644 --- a/http3/client.go +++ b/http3/client.go @@ -68,7 +68,9 @@ type client struct { logger utils.Logger } -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { +var _ roundTripCloser = &client{} + +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() } else if len(conf.Versions) == 0 { @@ -434,3 +436,15 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, return res, requestError{} } + +func (c *client) HandshakeComplete() bool { + if c.conn == nil { + return false + } + select { + case <-c.conn.HandshakeComplete().Done(): + return true + default: + return false + } +} diff --git a/http3/client_test.go b/http3/client_test.go index a5b049c2..79ad281a 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -26,7 +26,7 @@ import ( var _ = Describe("Client", func() { var ( - client *client + cl *client req *http.Request origDialAddr = dialAddr handshakeCtx context.Context // an already canceled context @@ -35,10 +35,10 @@ var _ = Describe("Client", func() { BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" - var err error - client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) + c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) Expect(err).ToNot(HaveOccurred()) - Expect(client.hostname).To(Equal(hostname)) + cl = c.(*client) + Expect(cl.hostname).To(Equal(hostname)) req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) @@ -168,7 +168,7 @@ var _ = Describe("Client", func() { It("refuses to do requests for the wrong host", func() { req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTripOpt(req, RoundTripOpt{}) + _, err = cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) }) @@ -179,7 +179,7 @@ var _ = Describe("Client", func() { dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } - _, err = client.RoundTripOpt(req, RoundTripOpt{}) + _, err = cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) }) @@ -220,7 +220,7 @@ var _ = Describe("Client", func() { It("hijacks a bidirectional stream of unknown frame type", func() { frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return true, nil @@ -235,7 +235,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(request, RoundTripOpt{}) + _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -243,7 +243,7 @@ var _ = Describe("Client", func() { It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return false, nil @@ -259,14 +259,14 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) + _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) It("closes the connection when hijacker returned error", func() { frameTypeChan := make(chan FrameType, 1) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { + cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return false, errors.New("error in hijacker") @@ -282,7 +282,7 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) + _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) @@ -291,7 +291,7 @@ var _ = Describe("Client", func() { testErr := errors.New("test error") unknownStr := mockquic.NewMockStream(mockCtrl) done := make(chan struct{}) - client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { + cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { defer close(done) Expect(e).To(MatchError(testErr)) Expect(ft).To(BeZero()) @@ -306,7 +306,7 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() - _, err := client.RoundTripOpt(request, RoundTripOpt{}) + _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -348,7 +348,7 @@ var _ = Describe("Client", func() { It("hijacks an unidirectional stream of unknown stream type", func() { streamTypeChan := make(chan StreamType, 1) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { Expect(err).ToNot(HaveOccurred()) streamTypeChan <- st return true @@ -365,7 +365,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -375,7 +375,7 @@ var _ = Describe("Client", func() { testErr := errors.New("test error") done := make(chan struct{}) unknownStr := mockquic.NewMockStream(mockCtrl) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { + cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { defer close(done) Expect(st).To(BeZero()) Expect(str).To(Equal(unknownStr)) @@ -389,7 +389,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -397,7 +397,7 @@ var _ = Describe("Client", func() { It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { streamTypeChan := make(chan StreamType, 1) - client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { + cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { Expect(err).ToNot(HaveOccurred()) streamTypeChan <- st return false @@ -415,7 +415,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError @@ -467,7 +467,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -492,7 +492,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead }) @@ -515,7 +515,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -539,7 +539,7 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorMissingSettings)) close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -563,7 +563,7 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorFrameError)) close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -586,13 +586,13 @@ var _ = Describe("Client", func() { Expect(code).To(BeEquivalentTo(errorIDError)) close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) It("errors when the server advertises datagram support (and we enabled support for it)", func() { - client.opts.EnableDatagram = true + cl.opts.EnableDatagram = true b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) r := bytes.NewReader(b) @@ -613,7 +613,7 @@ var _ = Describe("Client", func() { Expect(reason).To(Equal("missing QUIC Datagram support")) close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) @@ -705,7 +705,7 @@ var _ = Describe("Client", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) @@ -721,7 +721,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) }) @@ -736,7 +736,7 @@ var _ = Describe("Client", func() { str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.ProtoMajor).To(Equal(3)) @@ -753,7 +753,7 @@ var _ = Describe("Client", func() { ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) + rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.ProtoMajor).To(Equal(3)) @@ -788,7 +788,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("test done") }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(strBuf) Expect(hfs).To(HaveKeyWithValue(":method", "POST")) @@ -812,7 +812,7 @@ var _ = Describe("Client", func() { }) closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) Eventually(closed).Should(BeClosed()) }) @@ -831,7 +831,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors // the response body is sent asynchronously, while already reading the response str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - req, err := client.RoundTripOpt(req, RoundTripOpt{}) + req, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) Expect(req.ContentLength).To(BeEquivalentTo(1337)) Eventually(done).Should(BeClosed()) @@ -844,7 +844,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) Eventually(closed).Should(BeClosed()) }) @@ -856,7 +856,7 @@ var _ = Describe("Client", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) }) @@ -876,7 +876,7 @@ var _ = Describe("Client", func() { errChan := make(chan error) go func() { - _, err := client.RoundTripOpt(req, roundTripOpt) + _, err := cl.RoundTripOpt(req, roundTripOpt) errChan <- err }() Consistently(errChan).ShouldNot(Receive()) @@ -906,7 +906,7 @@ var _ = Describe("Client", func() { <-canceled return 0, errors.New("test done") }) - _, err := client.RoundTripOpt(req, roundTripOpt) + _, err := cl.RoundTripOpt(req, roundTripOpt) Expect(err).To(MatchError("test done")) Eventually(done).Should(BeClosed()) }) @@ -929,7 +929,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(done).Should(BeClosed()) @@ -950,7 +950,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) @@ -989,7 +989,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) @@ -1012,7 +1012,7 @@ var _ = Describe("Client", func() { str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) + rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go new file mode 100644 index 00000000..ce409a75 --- /dev/null +++ b/http3/mock_roundtripcloser_test.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: roundtrip.go + +// Package http3 is a generated GoMock package. +package http3 + +import ( + http "net/http" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockRoundTripCloser is a mock of RoundTripCloser interface. +type MockRoundTripCloser struct { + ctrl *gomock.Controller + recorder *MockRoundTripCloserMockRecorder +} + +// MockRoundTripCloserMockRecorder is the mock recorder for MockRoundTripCloser. +type MockRoundTripCloserMockRecorder struct { + mock *MockRoundTripCloser +} + +// NewMockRoundTripCloser creates a new mock instance. +func NewMockRoundTripCloser(ctrl *gomock.Controller) *MockRoundTripCloser { + mock := &MockRoundTripCloser{ctrl: ctrl} + mock.recorder = &MockRoundTripCloserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRoundTripCloser) EXPECT() *MockRoundTripCloserMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRoundTripCloser) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRoundTripCloserMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close)) +} + +// HandshakeComplete mocks base method. +func (m *MockRoundTripCloser) HandshakeComplete() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandshakeComplete") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HandshakeComplete indicates an expected call of HandshakeComplete. +func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete)) +} + +// RoundTripOpt mocks base method. +func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1) + ret0, _ := ret[0].(*http.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RoundTripOpt indicates an expected call of RoundTripOpt. +func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) +} diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 4cb41d9a..d9812abb 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strings" "sync" @@ -17,6 +18,7 @@ import ( type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) + HandshakeComplete() bool io.Closer } @@ -75,7 +77,8 @@ type RoundTripper struct { // Zero means to use a default limit. MaxResponseHeaderBytes int64 - clients map[string]roundTripCloser + newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests + clients map[string]roundTripCloser } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -131,11 +134,20 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } hostname := authorityAddr("https", hostnameFromRequest(req)) - cl, err := r.getClient(hostname, opt.OnlyCachedConn) + cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) if err != nil { return nil, err } - return cl.RoundTripOpt(req, opt) + rsp, err := cl.RoundTripOpt(req, opt) + if err != nil { + r.removeClient(hostname) + if isReused { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return r.RoundTripOpt(req, opt) + } + } + } + return rsp, err } // RoundTrip does a round trip. @@ -143,7 +155,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() @@ -154,10 +166,14 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo client, ok := r.clients[hostname] if !ok { if onlyCached { - return nil, ErrNoCachedConn + return nil, false, ErrNoCachedConn } var err error - client, err = newClient( + newCl := newClient + if r.newClient != nil { + newCl = r.newClient + } + client, err = newCl( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -171,11 +187,22 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo r.Dial, ) if err != nil { - return nil, err + return nil, false, err } r.clients[hostname] = client + } else if client.HandshakeComplete() { + isReused = true } - return client, nil + return client, isReused, nil +} + +func (r *RoundTripper) removeClient(hostname string) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.clients == nil { + return + } + delete(r.clients, hostname) } // Close closes the QUIC connections that this RoundTripper has used diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 2a3fbd9b..b03eef60 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -10,27 +10,14 @@ import ( "time" "github.com/quic-go/quic-go" - mockquic "github.com/quic-go/quic-go/internal/mocks/quic" + "github.com/quic-go/quic-go/internal/qerr" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -type mockClient struct { - closed bool -} - -func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) { - return &http.Response{Request: req}, nil -} - -func (m *mockClient) Close() error { - m.closed = true - return nil -} - -var _ roundTripCloser = &mockClient{} +//go:generate sh -c "./../mockgen_private.sh http3 mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 roundTripCloser" type mockBody struct { reader bytes.Reader @@ -60,57 +47,29 @@ func (m *mockBody) Close() error { var _ = Describe("RoundTripper", func() { var ( - rt *RoundTripper - req1 *http.Request - conn *mockquic.MockEarlyConnection - handshakeCtx context.Context // an already canceled context + rt *RoundTripper + req *http.Request ) BeforeEach(func() { rt = &RoundTripper{} var err error - req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) + req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) Expect(err).ToNot(HaveOccurred()) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - handshakeCtx = ctx }) Context("dialing hosts", func() { - origDialAddr := dialAddr - - BeforeEach(func() { - conn = mockquic.NewMockEarlyConnection(mockCtrl) - origDialAddr = dialAddr - dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { - // return an error when trying to open a stream - // we don't want to test all the dial logic here, just that dialing happens at all - return conn, nil - } - }) - - AfterEach(func() { - dialAddr = origDialAddr - }) - It("creates new clients", func() { - closed := make(chan struct{}) testErr := errors.New("test err") req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-closed - return nil, errors.New("test done") - }).MaxTimes(1) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) + return cl, nil + } _, err = rt.RoundTrip(req) Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - Eventually(closed).Should(BeClosed()) }) It("uses the quic.Config, if provided", func() { @@ -121,7 +80,7 @@ var _ = Describe("RoundTripper", func() { return nil, errors.New("handshake error") } rt.QuicConfig = config - _, err := rt.RoundTrip(req1) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("handshake error")) Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) }) @@ -133,33 +92,144 @@ var _ = Describe("RoundTripper", func() { return nil, errors.New("handshake error") } rt.Dial = dialer - _, err := rt.RoundTrip(req1) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("handshake error")) Expect(dialed).To(BeTrue()) }) + }) + + Context("reusing clients", func() { + var req1, req2 *http.Request + + BeforeEach(func() { + var err error + req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(req1.URL).ToNot(Equal(req2.URL)) + }) It("reuses existing clients", func() { - closed := make(chan struct{}) + var count int + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + count++ + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + return &http.Response{Request: req}, nil + }).Times(2) + cl.EXPECT().HandshakeComplete().Return(true) + return cl, nil + } + rsp1, err := rt.RoundTrip(req1) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp1.Request.URL).To(Equal(req1.URL)) + rsp2, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp2.Request.URL).To(Equal(req2.URL)) + Expect(count).To(Equal(1)) + }) + + It("immediately removes a clients when a request errored", func() { testErr := errors.New("test err") - conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) - conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) - conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - <-closed - return nil, errors.New("test done") - }).MaxTimes(1) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) - req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTrip(req) + + var count int + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + count++ + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) + return cl, nil + } + _, err := rt.RoundTrip(req1) Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) - Expect(err).ToNot(HaveOccurred()) _, err = rt.RoundTrip(req2) Expect(err).To(MatchError(testErr)) - Expect(rt.clients).To(HaveLen(1)) - Eventually(closed).Should(BeClosed()) + Expect(count).To(Equal(2)) + }) + + It("recreates a client when a request times out", func() { + var reqCount int + cl1 := NewMockRoundTripCloser(mockCtrl) + cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + reqCount++ + if reqCount == 1 { // the first request is successful... + Expect(req.URL).To(Equal(req1.URL)) + return &http.Response{Request: req}, nil + } + // ... after that, the connection timed out in the background + Expect(req.URL).To(Equal(req2.URL)) + return nil, &qerr.IdleTimeoutError{} + }).Times(2) + cl1.EXPECT().HandshakeComplete().Return(true) + cl2 := NewMockRoundTripCloser(mockCtrl) + cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + return &http.Response{Request: req}, nil + }) + + var count int + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + count++ + if count == 1 { + return cl1, nil + } + return cl2, nil + } + rsp1, err := rt.RoundTrip(req1) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr)) + rsp2, err := rt.RoundTrip(req2) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr)) + }) + + It("only issues a request once, even if a timeout error occurs", func() { + var count int + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + count++ + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) + return cl, nil + } + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) + Expect(count).To(Equal(1)) + }) + + It("handles a burst of requests", func() { + wait := make(chan struct{}) + reqs := make(chan struct{}, 2) + var count int + rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { + count++ + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { + reqs <- struct{}{} + <-wait + return nil, &qerr.IdleTimeoutError{} + }).Times(2) + cl.EXPECT().HandshakeComplete() + return cl, nil + } + done := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + defer func() { done <- struct{}{} }() + _, err := rt.RoundTrip(req1) + Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) + }() + go func() { + defer GinkgoRecover() + defer func() { done <- struct{}{} }() + _, err := rt.RoundTrip(req2) + Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) + }() + // wait for both requests to be issued + Eventually(reqs).Should(Receive()) + Eventually(reqs).Should(Receive()) + close(wait) // now return the requests + Eventually(done).Should(Receive()) + Eventually(done).Should(Receive()) + Expect(count).To(Equal(1)) }) It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { @@ -181,66 +251,66 @@ var _ = Describe("RoundTripper", func() { }) It("rejects requests without a URL", func() { - req1.URL = nil - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) + req.URL = nil + req.Body = &mockBody{} + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: nil Request.URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) It("rejects request without a URL Host", func() { - req1.URL.Host = "" - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) + req.URL.Host = "" + req.Body = &mockBody{} + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: no Host in request URL")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) It("doesn't try to close the body if the request doesn't have one", func() { - req1.URL = nil - Expect(req1.Body).To(BeNil()) - _, err := rt.RoundTrip(req1) + req.URL = nil + Expect(req.Body).To(BeNil()) + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: nil Request.URL")) }) It("rejects requests without a header", func() { - req1.Header = nil - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) + req.Header = nil + req.Body = &mockBody{} + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: nil Request.Header")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) It("rejects requests with invalid header name fields", func() { - req1.Header.Add("foobär", "value") - _, err := rt.RoundTrip(req1) + req.Header.Add("foobär", "value") + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) }) It("rejects requests with invalid header name values", func() { - req1.Header.Add("foo", string([]byte{0x7})) - _, err := rt.RoundTrip(req1) + req.Header.Add("foo", string([]byte{0x7})) + _, err := rt.RoundTrip(req) Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) }) It("rejects requests with an invalid request method", func() { - req1.Method = "foobär" - req1.Body = &mockBody{} - _, err := rt.RoundTrip(req1) + req.Method = "foobär" + req.Body = &mockBody{} + _, err := rt.RoundTrip(req) Expect(err).To(MatchError("http3: invalid method \"foobär\"")) - Expect(req1.Body.(*mockBody).closed).To(BeTrue()) + Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) }) Context("closing", func() { It("closes", func() { rt.clients = make(map[string]roundTripCloser) - cl := &mockClient{} + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().Close() rt.clients["foo.bar"] = cl err := rt.Close() Expect(err).ToNot(HaveOccurred()) Expect(len(rt.clients)).To(BeZero()) - Expect(cl.closed).To(BeTrue()) }) It("closes a RoundTripper that has never been used", func() {