diff --git a/http3/client.go b/http3/client.go index 1438cfe2..414db51f 100644 --- a/http3/client.go +++ b/http3/client.go @@ -32,6 +32,11 @@ const ( defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB ) +type errConnUnusable struct{ e error } + +func (e *errConnUnusable) Unwrap() error { return e.e } +func (e *errConnUnusable) Error() string { return fmt.Sprintf("http3: conn unusable: %s", e.e.Error()) } + const max1xxResponses = 5 // arbitrary bound on number of informational responses var defaultQuicConfig = &quic.Config{ @@ -230,7 +235,7 @@ func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) { c.maxResponseHeaderBytes, ) if err != nil { - return nil, err + return nil, &errConnUnusable{e: err} } // Request Cancellation: diff --git a/http3/transport.go b/http3/transport.go index dd73a8bb..f1824e59 100644 --- a/http3/transport.go +++ b/http3/transport.go @@ -204,8 +204,12 @@ func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } } - trace := httptrace.ContextClientTrace(req.Context()) + return t.doRoundTripOpt(req, opt, false) +} + +func (t *Transport) doRoundTripOpt(req *http.Request, opt RoundTripOpt, isRetried bool) (*http.Response, error) { hostname := authorityAddr(hostnameFromURL(req.URL)) + trace := httptrace.ContextClientTrace(req.Context()) traceGetConn(trace, hostname) cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != nil { @@ -222,8 +226,8 @@ func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res t.removeClient(hostname) return nil, cl.dialErr } - traceGotConn(trace, cl.conn, isReused) defer cl.useCount.Add(-1) + traceGotConn(trace, cl.conn, isReused) rsp, err := cl.clientConn.RoundTrip(req) if err != nil { // request aborted due to context cancellation @@ -232,28 +236,51 @@ func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res return nil, err default: } - - // Retry the request on a new connection if: - // 1. it was sent on a reused connection, - // 2. this connection is now closed, - // 3. and the error is a timeout error. - select { - case <-cl.conn.Context().Done(): - t.removeClient(hostname) - if isReused { - var nerr net.Error - if errors.As(err, &nerr) && nerr.Timeout() { - return t.RoundTripOpt(req, opt) - } - } - return nil, err - default: + if isRetried { return nil, err } + + t.removeClient(hostname) + req, err = canRetryRequest(err, req) + if err != nil { + return nil, err + } + return t.doRoundTripOpt(req, opt, true) } return rsp, nil } +func canRetryRequest(err error, req *http.Request) (*http.Request, error) { + // error occurred while opening the stream, we can be sure that the request wasn't sent out + var connErr *errConnUnusable + if errors.As(err, &connErr) { + return req, nil + } + + // If the request stream is reset, we can only be sure that the request wasn't processed + // if the error code is H3_REQUEST_REJECTED. + var e *Error + if !errors.As(err, &e) || e.ErrorCode != ErrCodeRequestRejected { + return nil, err + } + // if the body is nil (or http.NoBody), it's safe to reuse this request and its body + if req.Body == nil || req.Body == http.NoBody { + return req, nil + } + // if the request body can be reset back to its original state via req.GetBody, do that + if req.GetBody != nil { + newBody, err := req.GetBody() + if err != nil { + return nil, err + } + reqCopy := *req + reqCopy.Body = newBody + req = &reqCopy + return &reqCopy, nil + } + return nil, fmt.Errorf("http3: Transport: cannot retry err [%w] after Request.Body was written; define Request.GetBody to avoid this error", err) +} + // RoundTrip does a round trip. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return t.RoundTripOpt(req, RoundTripOpt{}) diff --git a/http3/transport_test.go b/http3/transport_test.go index 1be74e2b..1baa9f5d 100644 --- a/http3/transport_test.go +++ b/http3/transport_test.go @@ -5,16 +5,17 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qerr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -257,28 +258,66 @@ func TestTransportConnectionReuse(t *testing.T) { // Requests reuse the same underlying QUIC connection. // If a request experiences an error, the behavior depends on the nature of that error. func TestTransportConnectionRedial(t *testing.T) { - // If it's connection error that is a timeout error, we re-dial a new connection. - // No error will be returned to the caller. - t.Run("timeout error", func(t *testing.T) { - testTransportConnectionRedial(t, true, &qerr.IdleTimeoutError{}, nil) + nonRetryableReq := httptest.NewRequest( + http.MethodGet, + "https://quic-go.org", + strings.NewReader("foobar"), + ) + require.Nil(t, nonRetryableReq.GetBody) + + retryableReq := nonRetryableReq.Clone(context.Background()) + retryableReq.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("foobaz")), nil + } + + // If the error occurs when opening the stream, it is safe to retry the request: + // We can be certain that it wasn't sent out (not even partially). + t.Run("error when opening the stream", func(t *testing.T) { + require.NoError(t, + testTransportConnectionRedial(t, nonRetryableReq, &errConnUnusable{errors.New("test")}, "foobar", true), + ) }) - // If it's a different connection error, the error is returned to the caller. - // The connection is not redialed. - t.Run("other error from the connection", func(t *testing.T) { - testErr := &quic.TransportError{ErrorCode: quic.ConnectionIDLimitError} - testTransportConnectionRedial(t, true, testErr, testErr) + // If the error occurs when opening the stream, it is safe to retry the request: + // We can be certain that it wasn't sent out (not even partially). + t.Run("non-retryable request error after opening the stream", func(t *testing.T) { + require.ErrorIs(t, + testTransportConnectionRedial(t, nonRetryableReq, assert.AnError, "foobar", false), + assert.AnError, + ) }) - // If the error is not related to the connection, we return that error. - // The underlying connection remains open and is reused for subsequent requests. - t.Run("other error not from the connection", func(t *testing.T) { - testErr := &quic.TransportError{ErrorCode: quic.ConnectionIDLimitError} - testTransportConnectionRedial(t, false, testErr, testErr) + t.Run("retryable request after opening the stream", func(t *testing.T) { + require.ErrorIs(t, + testTransportConnectionRedial(t, retryableReq, assert.AnError, "", false), + assert.AnError, + ) + }) + + t.Run("retryable request after H3_REQUEST_REJECTED", func(t *testing.T) { + require.NoError(t, + testTransportConnectionRedial(t, + retryableReq, + &Error{ErrorCode: ErrCodeRequestRejected}, + "foobaz", + true, + ), + ) + }) + + t.Run("retryable request where GetBody returns an error", func(t *testing.T) { + req := nonRetryableReq.Clone(context.Background()) + req.GetBody = func() (io.ReadCloser, error) { + return nil, assert.AnError + } + require.ErrorIs(t, + testTransportConnectionRedial(t, req, &Error{ErrorCode: ErrCodeRequestRejected}, "", false), + assert.AnError, + ) }) } -func testTransportConnectionRedial(t *testing.T, connClosed bool, roundtripErr, expectedErr error) { +func testTransportConnectionRedial(t *testing.T, req *http.Request, roundtripErr error, expectedBody string, expectRedial bool) error { mockCtrl := gomock.NewController(t) cl := NewMockClientConn(mockCtrl) conn := mockquic.NewMockEarlyConnection(mockCtrl) @@ -294,47 +333,27 @@ func testTransportConnectionRedial(t *testing.T, connClosed bool, roundtripErr, newClientConn: func(quic.EarlyConnection) clientConn { return cl }, } - // the first request succeeds - req1 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil) - cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil) - rsp, err := tr.RoundTrip(req1) - require.NoError(t, err) - require.Equal(t, req1, rsp.Request) - require.Equal(t, 1, dialCount) + var body string + cl.EXPECT().RoundTrip(req).Return(nil, roundtripErr) + if expectRedial { + cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) { + b, err := io.ReadAll(r.Body) + if err != nil { + panic(fmt.Sprintf("reading body failed: %v", err)) + } + body = string(b) + return &http.Response{Request: req}, nil + }) + } - // the second request reuses the QUIC connection, and encounters an error - req2 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if connClosed { - cancel() - } - conn.EXPECT().Context().Return(ctx) - cl.EXPECT().RoundTrip(req2).Return(nil, roundtripErr) - if expectedErr == nil { - cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil) - } - rsp, err = tr.RoundTrip(req2) - if expectedErr == nil { - require.NoError(t, err) - require.Equal(t, req2, rsp.Request) - require.Equal(t, 2, dialCount) + _, err := tr.RoundTrip(req) + if !expectRedial { + assert.Equal(t, 1, dialCount) } else { - require.ErrorIs(t, err, expectedErr) - require.Equal(t, 1, dialCount) + assert.Equal(t, 2, dialCount) + assert.Equal(t, expectedBody, body) } - - // if the error was not a connection error, the next request reuses the connection - if connClosed { - return - } - currentDialCount := dialCount - req3 := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file3.html", nil) - cl.EXPECT().RoundTrip(req3).Return(&http.Response{Request: req3}, nil) - rsp, err = tr.RoundTrip(req3) - require.NoError(t, err) - require.Equal(t, req3, rsp.Request) - require.Equal(t, currentDialCount, dialCount) // no new connection was dialed + return err } func TestTransportRequestContextCancellation(t *testing.T) { diff --git a/integrationtests/self/http_shutdown_test.go b/integrationtests/self/http_shutdown_test.go index 54a8136d..5aec5b57 100644 --- a/integrationtests/self/http_shutdown_test.go +++ b/integrationtests/self/http_shutdown_test.go @@ -2,6 +2,7 @@ package self_test import ( "context" + "crypto/tls" "fmt" "io" "net" @@ -142,7 +143,16 @@ func TestGracefulShutdownPendingStreams(t *testing.T) { }) var server *http3.Server port := startHTTPServer(t, mux, func(s *http3.Server) { server = s }) - client := newHTTP3Client(t) + connChan := make(chan quic.EarlyConnection, 1) + tr := &http3.Transport{ + TLSClientConfig: getTLSClientConfigWithoutServerName(), + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + conn, err := quic.DialAddrEarly(ctx, addr, tlsCfg, cfg) + connChan <- conn + return conn, err + }, + } + cl := &http.Client{Transport: tr} proxy := quicproxy.Proxy{ Conn: newUDPConnLocalhost(t), @@ -159,7 +169,7 @@ func TestGracefulShutdownPendingStreams(t *testing.T) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/helloworld", proxyPort), nil) require.NoError(t, err) go func() { - resp, err := client.Do(req) + resp, err := cl.Do(req) if err != nil { errChan <- err return @@ -180,14 +190,27 @@ func TestGracefulShutdownPendingStreams(t *testing.T) { go func() { shutdownChan <- server.Shutdown(ctx) }() time.Sleep(rtt / 2) // wait for the server to start shutting down + var conn quic.EarlyConnection + select { + case conn = <-connChan: + case <-time.After(time.Second): + t.Fatal("connection was not opened") + } + // make sure that the server rejects further requests for range 3 { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/helloworld", proxyPort), nil) + str, err := conn.OpenStreamSync(ctx) require.NoError(t, err) - _, err = client.Do(req) - var h3err *http3.Error - require.ErrorAs(t, err, &h3err) - require.Equal(t, http3.ErrCodeRequestRejected, h3err.ErrorCode) + str.Write([]byte("foobar")) + select { + case <-str.Context().Done(): + case <-time.After(time.Second): + t.Fatal("stream was not rejected") + } + _, err = str.Read(make([]byte, 10)) + var serr *quic.StreamError + require.ErrorAs(t, err, &serr) + require.Equal(t, quic.StreamErrorCode(http3.ErrCodeRequestRejected), serr.ErrorCode) } cancel() diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 16065dce..875471ef 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -16,6 +16,7 @@ import ( "net/textproto" "os" "strconv" + "strings" "sync/atomic" "testing" "time" @@ -208,8 +209,8 @@ func TestHTTPMultipleRequests(t *testing.T) { func TestContentLengthForSmallResponse(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("foo")) - w.Write([]byte("bar")) + io.WriteString(w, "foo") + io.WriteString(w, "bar") }) port := startHTTPServer(t, mux) @@ -297,7 +298,7 @@ func TestHTTPErrAbortHandler(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/abort", func(w http.ResponseWriter, r *http.Request) { // no recover here as it will interfere with the handler - w.Write([]byte("foobar")) + io.WriteString(w, "foobar") w.(http.Flusher).Flush() // wait for the client to receive the response <-respChan @@ -374,19 +375,36 @@ func TestHTTPServerIdleTimeout(t *testing.T) { mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "Hello, World!\n") }) - port := startHTTPServer(t, mux, func(s *http3.Server) { s.IdleTimeout = 100 * time.Millisecond }) + idleTimeout := scaleDuration(10 * time.Millisecond) + port := startHTTPServer(t, mux, func(s *http3.Server) { s.IdleTimeout = idleTimeout }) + + connChan := make(chan quic.EarlyConnection, 1) + tr := &http3.Transport{ + TLSClientConfig: getTLSClientConfigWithoutServerName(), + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + conn, err := quic.DialAddrEarly(ctx, addr, tlsCfg, cfg) + connChan <- conn + return conn, err + }, + } + t.Cleanup(func() { tr.Close() }) + cl := &http.Client{Transport: tr} - cl := newHTTP3Client(t) _, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) require.NoError(t, err) - time.Sleep(150 * time.Millisecond) + var conn quic.EarlyConnection + select { + case conn = <-connChan: + case <-time.After(time.Second): + t.Fatal("connection was not opened") + } - _, err = cl.Get(fmt.Sprintf("https://localhost:%d/hello", port)) - require.Error(t, err) - var appErr *quic.ApplicationError - require.ErrorAs(t, err, &appErr) - require.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode) + select { + case <-time.After(3 * idleTimeout): + t.Fatal("connection was not closed") + case <-conn.Context().Done(): + } } func TestHTTPReestablishConnectionAfterDialError(t *testing.T) { @@ -442,7 +460,7 @@ func TestHTTPClientRequestContextCancellation(t *testing.T) { mux.HandleFunc("/cancel-after", func(w http.ResponseWriter, r *http.Request) { // TODO(#4508): check for request context cancellations for { - if _, err := w.Write([]byte("foobar")); err != nil { + if _, err := io.WriteString(w, "foobar"); err != nil { errChan <- err return } @@ -495,7 +513,7 @@ func TestHTTPDeadlines(t *testing.T) { require.NoError(t, rc.SetReadDeadline(time.Now().Add(deadlineDelay))) body, err := io.ReadAll(r.Body) resultChan <- result{body: body, err: err} - w.Write([]byte("ok")) + io.WriteString(w, "ok") }) expectedEnd := time.Now().Add(deadlineDelay) @@ -728,7 +746,7 @@ func TestHTTPStreamedRequests(t *testing.T) { if err != nil { return } - if _, err := w.Write([]byte(msg)); err != nil { + if _, err := io.WriteString(w, msg); err != nil { errChan <- err return } @@ -773,7 +791,7 @@ func TestHTTP1xxResponse(t *testing.T) { w.Header().Add("Link", header1) w.Header().Add("Link", header2) w.WriteHeader(http.StatusEarlyHints) - w.Write([]byte(data)) + io.WriteString(w, data) w.WriteHeader(http.StatusOK) }) @@ -843,7 +861,7 @@ func TestHTTP1xxTerminalResponse(t *testing.T) { func TestHTTP0RTT(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete))) + io.WriteString(w, strconv.FormatBool(!r.TLS.HandshakeComplete)) }) port := startHTTPServer(t, mux) @@ -952,3 +970,132 @@ func TestHTTPStreamer(t *testing.T) { require.NoError(t, err) require.Equal(t, PRData, repl) } + +type blackHoleConn struct { + net.PacketConn + block atomic.Bool + close chan struct{} +} + +func (c *blackHoleConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.PacketConn.WriteTo(b, addr) +} + +func (c *blackHoleConn) ReadFrom(b []byte) (int, net.Addr, error) { + if c.block.Load() { + <-c.close + return 0, nil, errors.New("blocked") + } + n, _, err := c.PacketConn.ReadFrom(b) + if c.block.Load() { + <-c.close + return 0, nil, errors.New("blocked") + } + return n, nil, err +} + +func (c *blackHoleConn) Close() error { + close(c.close) + return c.PacketConn.Close() +} + +func (c *blackHoleConn) StartBlocking() { c.block.Store(true) } + +func TestHTTPRequestRetryAfterIdleTimeout(t *testing.T) { + t.Run("only cached conn", func(t *testing.T) { + testHTTPRequestRetryAfterIdleTimeout(t, true) + }) + t.Run("allow re-dialing", func(t *testing.T) { + testHTTPRequestRetryAfterIdleTimeout(t, false) + }) +} + +func testHTTPRequestRetryAfterIdleTimeout(t *testing.T, onlyCachedConn bool) { + t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") + + mux := http.NewServeMux() + mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }) + port := startHTTPServer(t, mux, func(s *http3.Server) {}) + + firstConn := &blackHoleConn{PacketConn: newUDPConnLocalhost(t), close: make(chan struct{})} + secondConn := newUDPConnLocalhost(t) + conns := []net.PacketConn{firstConn, secondConn} + require.NotEqual(t, firstConn.LocalAddr().String(), secondConn.LocalAddr().String()) + + idleTimeout := scaleDuration(10 * time.Millisecond) + connChan := make(chan quic.EarlyConnection, 2) + tr := &http3.Transport{ + TLSClientConfig: getTLSClientConfigWithoutServerName(), + QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}), + Dial: func(ctx context.Context, a string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + conn := conns[0] + conns = conns[1:] + addr, err := net.ResolveUDPAddr("udp", a) + if err != nil { + return nil, err + } + c, err := quic.DialEarly(ctx, conn, addr, tlsCfg, cfg) + if err != nil { + return nil, err + } + connChan <- c + return c, nil + }, + DisableCompression: true, + } + t.Cleanup(func() { tr.Close() }) + + var headersCount int + req, err := http.NewRequestWithContext( + httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + WroteHeaders: func() { headersCount++ }, + }), + http.MethodGet, + fmt.Sprintf("https://127.0.0.1:%d/remote-addr", port), + // Add a body (wrappped so that http.NewRequest doesn't set the GetBody callback), + // to make it impossible to retry this request. + // This tests that the detection logic works properly: + // If the request fails before the stream can be opened, it is always safe to retry. + io.LimitReader(strings.NewReader("foobar"), 1000), + ) + require.NoError(t, err) + + resp, err := tr.RoundTripOpt(req, http3.RoundTripOpt{}) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, firstConn.LocalAddr().String(), string(body)) + + firstConn.StartBlocking() + // wait for the connection to time out + select { + case c := <-connChan: + select { + case <-c.Context().Done(): + case <-time.After(time.Second): + t.Fatal("connection did not time out") + } + case <-time.After(time.Second): + t.Fatal("no connection was created") + } + + // second request should succeed after re-dialing + resp, err = tr.RoundTripOpt(req, http3.RoundTripOpt{OnlyCachedConn: onlyCachedConn}) + if onlyCachedConn { + require.EqualError(t, err, "http3: no cached connection was available") + require.Len(t, conns, 1) // no second dial attempt + require.Equal(t, 1, headersCount) + return + } + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err = io.ReadAll(&readerWithTimeout{Reader: resp.Body, Timeout: 2 * time.Second}) + require.NoError(t, err) + require.Equal(t, secondConn.LocalAddr().String(), string(body)) + + require.Equal(t, 2, headersCount) + require.Empty(t, conns) // make sure we dialed 2 connections +}