Files
quic-go/http3/transport_test.go

552 lines
16 KiB
Go

package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
// mockBody is a test double for a request body.
// Reads are served from an in-memory reader unless readErr is set,
// and Close records that it has been called.
type mockBody struct {
	reader   bytes.Reader
	readErr  error // if set, every Read fails with this error
	closeErr error // returned by Close
	closed   bool  // whether Close has been called
}

// compile-time check that *mockBody can be used as a http.Request.Body
var _ io.ReadCloser = (*mockBody)(nil)

// Read returns readErr if it is set. Otherwise it reads from the data set via SetData.
func (b *mockBody) Read(p []byte) (int, error) {
	if err := b.readErr; err != nil {
		return 0, err
	}
	return b.reader.Read(p)
}

// SetData replaces the data served by Read. Reading starts at the beginning of data.
func (b *mockBody) SetData(data []byte) { b.reader.Reset(data) }

// Close marks the body as closed and returns closeErr.
func (b *mockBody) Close() error {
	b.closed = true
	return b.closeErr
}
// TestRequestValidation checks that RoundTrip rejects malformed requests with the expected error,
// and that the request body has been closed by the time RoundTrip returns.
func TestRequestValidation(t *testing.T) {
	// brokenReq creates a valid HTTPS GET request and lets the test case invalidate it.
	brokenReq := func(invalidate func(*http.Request)) *http.Request {
		r := httptest.NewRequest(http.MethodGet, "https://www.example.org/", nil)
		invalidate(r)
		return r
	}

	testCases := []struct {
		name string
		req  *http.Request
		// expectedErr, if set, has to match the entire error message.
		expectedErr string
		// expectedErrContains, if set, only has to be a substring of the error message.
		expectedErrContains string
	}{
		{
			name:        "plain HTTP",
			req:         httptest.NewRequest(http.MethodGet, "http://www.example.org/", nil),
			expectedErr: "http3: unsupported protocol scheme: http",
		},
		{
			name:        "missing URL",
			req:         brokenReq(func(r *http.Request) { r.URL = nil }),
			expectedErr: "http3: nil Request.URL",
		},
		{
			name:        "missing URL Host",
			req:         brokenReq(func(r *http.Request) { r.URL.Host = "" }),
			expectedErr: "http3: no Host in request URL",
		},
		{
			name:        "missing header",
			req:         brokenReq(func(r *http.Request) { r.Header = nil }),
			expectedErr: "http3: nil Request.Header",
		},
		{
			name:        "invalid header name",
			req:         brokenReq(func(r *http.Request) { r.Header.Add("foobär", "value") }),
			expectedErr: "http3: invalid http header field name \"foobär\"",
		},
		{
			name:                "invalid header value",
			req:                 brokenReq(func(r *http.Request) { r.Header.Add("foo", string([]byte{0x7})) }),
			expectedErrContains: "http3: invalid http header field value",
		},
		{
			name:        "invalid method",
			req:         brokenReq(func(r *http.Request) { r.Method = "foobär" }),
			expectedErr: "http3: invalid method \"foobär\"",
		},
	}

	var tr Transport
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			tc.req.Body = &mockBody{}
			_, err := tr.RoundTrip(tc.req)
			if tc.expectedErr != "" {
				require.EqualError(t, err, tc.expectedErr)
			}
			if tc.expectedErrContains != "" {
				require.Error(t, err)
				require.Contains(t, err.Error(), tc.expectedErrContains)
			}
			// even though the request was rejected, its body must have been closed
			require.True(t, tc.req.Body.(*mockBody).closed)
		})
	}
}
// TestTransportDialHostname checks which address is passed to the Dial callback,
// and which ServerName is set on the tls.Config, for request URLs with and without a port.
func TestTransportDialHostname(t *testing.T) {
	type dialArgs struct {
		addr       string
		serverName string
	}
	dialed := make(chan dialArgs, 1)
	tr := &Transport{
		Dial: func(_ context.Context, hostname string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
			dialed <- dialArgs{addr: hostname, serverName: tlsConf.ServerName}
			// Only the dial arguments are of interest, so the dial is failed on purpose.
			return nil, errors.New("test done")
		},
	}

	for _, tc := range []struct {
		name         string
		url          string
		expectedAddr string
	}{
		{name: "port set", url: "https://quic-go.net:1234", expectedAddr: "quic-go.net:1234"},
		// if the request doesn't have a port, the default port is used
		{name: "port not set", url: "https://quic-go.net", expectedAddr: "quic-go.net:443"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, tc.url, nil)
			_, err := tr.RoundTripOpt(req, RoundTripOpt{})
			require.EqualError(t, err, "test done")

			select {
			case args := <-dialed:
				require.Equal(t, tc.expectedAddr, args.addr)
				// the TLS ServerName never contains the port
				require.Equal(t, "quic-go.net", args.serverName)
			case <-time.After(time.Second):
				t.Fatal("timeout")
			}
		})
	}
}
// TestTransportDatagrams checks how Transport.EnableDatagrams interacts with the QUIC config
// that is handed to the Dial callback.
//
// The Dial callbacks only use non-fatal checks (assert, t.Error): FailNow, which require and
// t.Fatal end up calling, must only be called from the goroutine running the test, and nothing
// in this test guarantees that the transport invokes Dial on that goroutine.
func TestTransportDatagrams(t *testing.T) {
	// if the default quic.Config is used, the transport automatically enables QUIC datagrams
	t.Run("default quic.Config", func(t *testing.T) {
		tr := &Transport{
			EnableDatagrams: true,
			Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
				assert.True(t, quicConf.EnableDatagrams)
				// fail the dial: only the config passed to it is of interest
				return nil, assert.AnError
			},
		}
		req := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
		_, err := tr.RoundTripOpt(req, RoundTripOpt{})
		// receiving the dial error also proves that Dial was called
		require.ErrorIs(t, err, assert.AnError)
	})

	// if a custom quic.Config is used, the transport just checks that QUIC datagrams are enabled
	t.Run("custom quic.Config", func(t *testing.T) {
		tr := &Transport{
			EnableDatagrams: true,
			QUICConfig:      &quic.Config{EnableDatagrams: false},
			Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
				t.Error("dial should not be called")
				// return an error rather than a nil connection, so the transport never uses the result
				return nil, assert.AnError
			},
		}
		req := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
		_, err := tr.RoundTripOpt(req, RoundTripOpt{})
		require.EqualError(t, err, "HTTP Datagrams enabled, but QUIC Datagrams disabled")
	})
}
func TestTransportMultipleQUICVersions(t *testing.T) {
qconf := &quic.Config{
Versions: []quic.Version{protocol.Version2, protocol.Version1},
}
tr := &Transport{QUICConfig: qconf}
req := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
_, err := tr.RoundTrip(req)
require.EqualError(t, err, "can only use a single QUIC version for dialing a HTTP/3 connection")
}
// TestTransportConnectionReuse checks that the connection dialed for the first request
// to a host is used again for the next request to that host.
func TestTransportConnectionReuse(t *testing.T) {
	conn, _ := newConnPair(t)
	cl := NewMockClientConn(gomock.NewController(t))
	var dialCount int
	tr := &Transport{
		Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			dialCount++
			return conn, nil
		},
		newClientConn: func(quic.EarlyConnection) clientConn { return cl },
	}

	first := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil)
	second := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil)

	// if OnlyCachedConn is set, no connection is dialed
	_, err := tr.RoundTripOpt(first, RoundTripOpt{OnlyCachedConn: true})
	require.ErrorIs(t, err, ErrNoCachedConn)
	require.Zero(t, dialCount)

	// The first request establishes the connection, which is then used for the second request:
	// the dial count has to stay at 1.
	for _, req := range []*http.Request{first, second} {
		cl.EXPECT().RoundTrip(req).Return(&http.Response{Request: req}, nil)
		rsp, err := tr.RoundTrip(req)
		require.NoError(t, err)
		require.Equal(t, req, rsp.Request)
		require.Equal(t, 1, dialCount)
	}
}
// Requests reuse the same underlying QUIC connection.
// If a request experiences an error, the behavior depends on the nature of that error:
// every case below states whether a new connection is expected to be dialed.
func TestTransportConnectionRedial(t *testing.T) {
	// A request without a GetBody function (asserted below).
	nonRetryableReq := httptest.NewRequest(http.MethodGet, "https://quic-go.org", strings.NewReader("foobar"))
	require.Nil(t, nonRetryableReq.GetBody)

	// withGetBody returns a clone of nonRetryableReq that uses the given GetBody function.
	withGetBody := func(getBody func() (io.ReadCloser, error)) *http.Request {
		r := nonRetryableReq.Clone(context.Background())
		r.GetBody = getBody
		return r
	}
	retryableReq := withGetBody(func() (io.ReadCloser, error) {
		return io.NopCloser(strings.NewReader("foobaz")), nil
	})

	for _, tc := range []struct {
		name         string
		req          *http.Request
		roundTripErr error // returned by the first round trip attempt
		expectedBody string
		expectRedial bool
		expectedErr  error // nil means that the request is expected to succeed
	}{
		{
			// If the error occurs when opening the stream, it is safe to retry the request:
			// We can be certain that it wasn't sent out (not even partially).
			name:         "error when opening the stream",
			req:          nonRetryableReq,
			roundTripErr: &errConnUnusable{errors.New("test")},
			expectedBody: "foobar",
			expectRedial: true,
		},
		{
			// A generic error after the stream was opened is passed on to the caller,
			// and no new connection is dialed.
			name:         "non-retryable request error after opening the stream",
			req:          nonRetryableReq,
			roundTripErr: assert.AnError,
			expectedBody: "foobar",
			expectedErr:  assert.AnError,
		},
		{
			// Having a GetBody function doesn't change this:
			// a generic error is still passed on without dialing a new connection.
			name:         "retryable request after opening the stream",
			req:          retryableReq,
			roundTripErr: assert.AnError,
			expectedErr:  assert.AnError,
		},
		{
			// After H3_REQUEST_REJECTED, a new connection is dialed,
			// and the request is sent again with the body obtained from GetBody.
			name:         "retryable request after H3_REQUEST_REJECTED",
			req:          retryableReq,
			roundTripErr: &Error{ErrorCode: ErrCodeRequestRejected},
			expectedBody: "foobaz",
			expectRedial: true,
		},
		{
			// If GetBody fails, that error is returned, and no new connection is dialed.
			name:         "retryable request where GetBody returns an error",
			req:          withGetBody(func() (io.ReadCloser, error) { return nil, assert.AnError }),
			roundTripErr: &Error{ErrorCode: ErrCodeRequestRejected},
			expectedErr:  assert.AnError,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := testTransportConnectionRedial(t, tc.req, tc.roundTripErr, tc.expectedBody, tc.expectRedial)
			if tc.expectedErr == nil {
				require.NoError(t, err)
				return
			}
			require.ErrorIs(t, err, tc.expectedErr)
		})
	}
}
// testTransportConnectionRedial sends req through a Transport whose first round trip attempt
// fails with roundtripErr, and returns the error of Transport.RoundTrip.
//
// If expectRedial is set, a second round trip attempt is expected: it succeeds, exactly two
// connections must have been dialed, and the body sent on that attempt must equal expectedBody.
// Otherwise, exactly one connection must have been dialed, and expectedBody is not checked.
func testTransportConnectionRedial(t *testing.T, req *http.Request, roundtripErr error, expectedBody string, expectRedial bool) error {
	conn, _ := newConnPair(t)
	cl := NewMockClientConn(gomock.NewController(t))

	var dials int
	tr := &Transport{
		Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			dials++
			return conn, nil
		},
		newClientConn: func(quic.EarlyConnection) clientConn { return cl },
	}

	// the first attempt always fails with the injected error
	cl.EXPECT().RoundTrip(req).Return(nil, roundtripErr)

	// the body that was sent on the second attempt
	var resent string
	if expectRedial {
		cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) {
			data, err := io.ReadAll(r.Body)
			if err != nil {
				panic(fmt.Sprintf("reading body failed: %v", err))
			}
			resent = string(data)
			return &http.Response{Request: req}, nil
		})
	}

	_, err := tr.RoundTrip(req)
	if expectRedial {
		assert.Equal(t, 2, dials)
		assert.Equal(t, expectedBody, resent)
	} else {
		assert.Equal(t, 1, dials)
	}
	return err
}
// TestTransportRequestContextCancellation checks that a request failing because its context
// was cancelled doesn't cause a redial: the following request uses the same connection.
func TestTransportRequestContextCancellation(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cl := NewMockClientConn(mockCtrl)
	conn, _ := newConnPair(t)
	var dialCount int
	tr := &Transport{
		Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			dialCount++
			return conn, nil
		},
		newClientConn: func(quic.EarlyConnection) clientConn { return cl },
	}

	// roundTripOK sends a request that is expected to succeed,
	// on the one and only connection dialed during this test.
	roundTripOK := func(req *http.Request) {
		cl.EXPECT().RoundTrip(req).Return(&http.Response{Request: req}, nil)
		rsp, err := tr.RoundTrip(req)
		require.NoError(t, err)
		require.Equal(t, req, rsp.Request)
		require.Equal(t, 1, dialCount)
	}

	// the first request succeeds
	roundTripOK(httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil))

	// the second request reuses the QUIC connection, and runs into the cancelled context
	ctx, cancel := context.WithCancel(context.Background())
	cancelledReq := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil).WithContext(ctx)
	cl.EXPECT().RoundTrip(cancelledReq).DoAndReturn(
		func(*http.Request) (*http.Response, error) {
			cancel()
			return nil, context.Canceled
		},
	)
	_, err := tr.RoundTrip(cancelledReq)
	require.ErrorIs(t, err, context.Canceled)
	require.Equal(t, 1, dialCount)

	// the next request reuses the QUIC connection
	roundTripOK(httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil))
}
// TestTransportConnetionRedialHandshakeError checks that a failed dial is not remembered:
// the request that triggered it fails with the dial error, and the next request dials again.
func TestTransportConnetionRedialHandshakeError(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cl := NewMockClientConn(mockCtrl)
	conn, _ := newConnPair(t)
	var dials int
	tr := &Transport{
		Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			dials++
			// only the very first dial attempt fails
			if dials > 1 {
				return conn, nil
			}
			return nil, assert.AnError
		},
		newClientConn: func(quic.EarlyConnection) clientConn { return cl },
	}

	// The first request runs into the dial error. No round trip is expected on the mock.
	failing := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file1.html", nil)
	_, err := tr.RoundTrip(failing)
	require.ErrorIs(t, err, assert.AnError)
	require.Equal(t, 1, dials)

	// The second request dials a new connection, and succeeds.
	succeeding := httptest.NewRequest(http.MethodGet, "https://quic-go.net/file2.html", nil)
	cl.EXPECT().RoundTrip(succeeding).Return(&http.Response{Request: succeeding}, nil)
	rsp, err := tr.RoundTrip(succeeding)
	require.NoError(t, err)
	require.Equal(t, succeeding, rsp.Request)
	require.Equal(t, 2, dials)
}
// TestTransportCloseEstablishedConnections checks that closing the transport
// closes a connection that was established for an earlier request.
func TestTransportCloseEstablishedConnections(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	conn, _ := newConnPair(t)
	tr := &Transport{
		Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			return conn, nil
		},
		newClientConn: func(quic.EarlyConnection) clientConn {
			// every client connection handles exactly one (successful) round trip
			mock := NewMockClientConn(mockCtrl)
			mock.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{}, nil)
			return mock
		},
	}

	// establish the connection by sending a request
	_, err := tr.RoundTrip(httptest.NewRequest(http.MethodGet, "https://quic-go.net/foobar.html", nil))
	require.NoError(t, err)

	require.NoError(t, tr.Close())
	// the connection's context is done once the connection has been closed
	connClosed := conn.Context().Done()
	select {
	case <-connClosed:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestTransportCloseInFlightDials checks that closing the transport cancels the context of a
// Dial call that is still in progress, and that the waiting request fails with context.Canceled.
func TestTransportCloseInFlightDials(t *testing.T) {
	tr := &Transport{
		// This Dial never succeeds: it blocks until its context is cancelled (or for one second).
		Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(time.Second):
				return nil, errors.New("timeout")
			}
		},
	}

	req := httptest.NewRequest(http.MethodGet, "https://quic-go.net/foobar.html", nil)
	// buffered, so the goroutine can always deliver its result and exit
	result := make(chan error, 1)
	go func() {
		_, err := tr.RoundTrip(req)
		result <- err
	}()

	// While the dial is in progress, the request must not return.
	select {
	case err := <-result:
		t.Fatalf("received unexpected error: %v", err)
	case <-time.After(scaleDuration(10 * time.Millisecond)):
	}

	require.NoError(t, tr.Close())

	select {
	case err := <-result:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestTransportCloseIdleConnections checks that CloseIdleConnections closes a connection
// once the request that was using it has finished.
//
// Two requests to different hosts are started, each on its own connection, and both block
// until their context is cancelled. After every cancellation, CloseIdleConnections is called,
// and the connection of the request that just finished is expected to be closed.
func TestTransportCloseIdleConnections(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	conn1, _ := newConnPair(t)
	conn2, _ := newConnPair(t)
	// Both signal channels are buffered (one slot per request), so that the goroutines
	// sending on them can't block forever if the test aborts early.
	roundTripCalled := make(chan struct{}, 2)
	reqFinished := make(chan struct{}, 2)

	tr := &Transport{
		Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
			switch hostname {
			case "site1.com:443":
				return conn1, nil
			case "site2.com:443":
				return conn2, nil
			default:
				// Not t.Fatal: FailNow must only be called from the goroutine running the test,
				// and nothing here guarantees that Dial is invoked on that goroutine.
				t.Errorf("unexpected hostname: %s", hostname)
				return nil, errors.New("unexpected hostname")
			}
		},
		newClientConn: func(quic.EarlyConnection) clientConn {
			cl := NewMockClientConn(mockCtrl)
			cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) {
				roundTripCalled <- struct{}{}
				// keep the request in flight until the test cancels its context
				<-r.Context().Done()
				return nil, nil
			})
			return cl
		},
	}

	req1 := httptest.NewRequest(http.MethodGet, "https://site1.com", nil)
	req2 := httptest.NewRequest(http.MethodGet, "https://site2.com", nil)
	require.NotEqual(t, req1.Host, req2.Host)
	// The deferred cancels release the contexts (and unblock the mocked round trips)
	// on every exit path, including a failed assertion.
	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	ctx2, cancel2 := context.WithCancel(context.Background())
	defer cancel2()
	req1 = req1.WithContext(ctx1)
	req2 = req2.WithContext(ctx2)

	// waitFor blocks until ch delivers a value or is closed, and fails the test after one second.
	waitFor := func(ch <-chan struct{}, what string) {
		t.Helper()
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatalf("timeout waiting for %s", what)
		}
	}

	// The results of the round trips are irrelevant here, only their completion matters.
	go func() {
		tr.RoundTrip(req1)
		reqFinished <- struct{}{}
	}()
	go func() {
		tr.RoundTrip(req2)
		reqFinished <- struct{}{}
	}()
	waitFor(roundTripCalled, "the first round trip to start")
	waitFor(roundTripCalled, "the second round trip to start")
	// Both requests have been started.

	cancel1()
	waitFor(reqFinished, "req1 to finish")
	// req1 is finished
	tr.CloseIdleConnections()
	waitFor(conn1.Context().Done(), "conn1 to be closed")

	cancel2()
	waitFor(reqFinished, "req2 to finish")
	// all requests are finished
	tr.CloseIdleConnections()
	waitFor(conn2.Context().Done(), "conn2 to be closed")
}