// Files
// quic-go/http3/server_test.go
//
// 1025 lines
// 33 KiB
// Go

package http3
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"runtime"
"testing"
"time"
"github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/testdata"
"github.com/quic-go/quic-go/quicvarint"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
// TestConfigureTLSConfig checks that ConfigureTLSConfig forces the ALPN list to
// h3, both for a static config and for configs returned by GetConfigForClient.
func TestConfigureTLSConfig(t *testing.T) {
	t.Run("basic config", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{})
		require.Equal(t, []string{NextProtoH3}, conf.NextProtos)
	})

	// ALPN values that are already present are replaced by h3
	t.Run("ALPN set", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{NextProtos: []string{"foo", "bar"}})
		require.Equal(t, []string{NextProtoH3}, conf.NextProtos)
	})

	// for configs that define GetConfigForClient, the ALPN is set to h3
	t.Run("GetConfigForClient", func(t *testing.T) {
		staticConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
		conf := ConfigureTLSConfig(&tls.Config{
			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
				return staticConf, nil
			},
		})
		innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "example.com"})
		require.NoError(t, err)
		require.NotNil(t, innerConf)
		require.Equal(t, []string{NextProtoH3}, innerConf.NextProtos)
		// make sure the original config was not modified
		require.Equal(t, []string{"foo", "bar"}, staticConf.NextProtos)
	})

	// GetConfigForClient might return a nil tls.Config
	t.Run("GetConfigForClient returns nil", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{
			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
				return nil, nil
			},
		})
		innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "example.com"})
		require.NoError(t, err)
		require.Nil(t, innerConf)
	})
}
// TestServerSettings verifies the SETTINGS frame the server sends on its control
// stream: once with datagrams enabled, once with additional custom settings.
func TestServerSettings(t *testing.T) {
	cases := []struct {
		name     string
		datagram bool
		extra    map[uint64]uint64
	}{
		{name: "enable datagrams", datagram: true},
		{name: "additional settings", extra: map[uint64]uint64{13: 37}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			testServerSettings(t, tc.datagram, tc.extra)
		})
	}
}
// testServerSettings runs handleConn against a mocked QUIC connection and checks
// the SETTINGS frame written to the server's control stream.
//
// enableDatagrams and other are configured on the Server and must be reflected
// in the frame. Extended CONNECT is expected to always be advertised.
func testServerSettings(t *testing.T, enableDatagrams bool, other map[uint64]uint64) {
	s := Server{
		EnableDatagrams:    enableDatagrams,
		AdditionalSettings: other,
	}
	s.init()

	// closed on return to release the mocked Accept{,Uni}Stream calls, which
	// block until then
	testDone := make(chan struct{})
	defer close(testDone)

	// Buffered: if the test already gave up (timeout), the mocked Write must
	// not block the server goroutine forever.
	settingsChan := make(chan []byte, 1)
	mockCtrl := gomock.NewController(t)
	conn := mockquic.NewMockEarlyConnection(mockCtrl)
	controlStr := mockquic.NewMockStream(mockCtrl)
	controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
		settingsChan <- b
		return len(b), nil
	})
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, assert.AnError
	}).MaxTimes(1)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
		<-testDone
		return nil, assert.AnError
	}).MaxTimes(1)
	conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	conn.EXPECT().LocalAddr().AnyTimes()
	conn.EXPECT().Context().Return(context.Background()).AnyTimes()
	go s.handleConn(conn)

	select {
	case b := <-settingsChan:
		// control stream: stream type varint, followed by the SETTINGS frame
		typ, l, err := quicvarint.Parse(b)
		require.NoError(t, err)
		require.EqualValues(t, streamTypeControlStream, typ)
		fp := &frameParser{r: bytes.NewReader(b[l:])}
		f, err := fp.ParseNext()
		require.NoError(t, err)
		require.IsType(t, &settingsFrame{}, f)
		sf := f.(*settingsFrame)
		// Extended CONNECT is always supported
		require.True(t, sf.ExtendedConnect)
		require.Equal(t, enableDatagrams, sf.Datagram)
		require.Equal(t, other, sf.Other)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
func decodeHeader(t *testing.T, r io.Reader) map[string][]string {
fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil)
frame, err := (&frameParser{r: r}).ParseNext()
require.NoError(t, err)
require.IsType(t, &headersFrame{}, frame)
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(r, data)
require.NoError(t, err)
hfs, err := decoder.DecodeFull(data)
require.NoError(t, err)
for _, p := range hfs {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
// TestServerRequestHandling exercises request handling on a mocked stream.
// It covers status codes, Content-Length (only when the response was not
// flushed), HEAD responses without a body, and echoing a POST body.
func TestServerRequestHandling(t *testing.T) {
	t.Run("200 with an empty handler", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		require.Empty(t, body)
	})

	t.Run("content-length", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusTeapot)
				w.Write([]byte("foobar"))
			},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"418"}, hfs[":status"])
		require.Equal(t, []string{"6"}, hfs["content-length"])
		require.Equal(t, []byte("foobar"), body)
	})

	// the headers go out with the first flush, before the full length is known
	t.Run("no content-length when flushed", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("foo"))
				w.(http.Flusher).Flush()
				w.Write([]byte("bar"))
			},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		require.NotContains(t, hfs, "content-length")
		require.Equal(t, []byte("foobar"), body)
	})

	t.Run("HEAD request", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("foobar"))
			},
			httptest.NewRequest(http.MethodHead, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		require.Empty(t, body)
	})

	t.Run("POST request", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusTeapot)
				data, _ := io.ReadAll(r.Body)
				w.Write(data)
			},
			httptest.NewRequest(http.MethodPost, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))),
		)
		require.Equal(t, []string{"418"}, hfs[":status"])
		require.Equal(t, []byte("foobar"), body)
	})
}
func encodeRequest(t *testing.T, req *http.Request) io.Reader {
var buf bytes.Buffer
rw := newRequestWriter()
require.NoError(t, rw.WriteRequestHeader(&buf, req, false))
if req.Body != nil {
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
buf.Write((&dataFrame{Length: uint64(len(body))}).Append(nil))
buf.Write(body)
}
return bytes.NewReader(buf.Bytes())
}
// testServerRequestHandling serves req with handler via Server.handleRequest on
// a mocked request stream. It returns the decoded fields of the response's
// HEADERS frame and the concatenated payload of all DATA frames that follow it.
func testServerRequestHandling(t *testing.T,
handler http.HandlerFunc,
req *http.Request,
) (responseHeaders map[string][]string, body []byte) {
// everything the server writes to the request stream is collected here
responseBuf := &bytes.Buffer{}
mockCtrl := gomock.NewController(t)
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Context().Return(context.Background()).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
// no Times() modifier: the server must call each of these exactly once
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
// the client side of the stream: the encoded request, served from memory
str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
s := &Server{
TLSConfig: testdata.GetTLSConfig(),
Handler: handler,
}
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
qconn.EXPECT().LocalAddr().AnyTimes()
qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
s.handleRequest(conn, str, nil, qpack.NewDecoder(nil))
// the response starts with a HEADERS frame...
hfs := decodeHeader(t, responseBuf)
fp := frameParser{r: responseBuf}
var content []byte
// ...followed by DATA frames only, until the buffer is drained (io.EOF)
for {
frame, err := fp.ParseNext()
if err == io.EOF {
break
}
require.NoError(t, err)
require.IsType(t, &dataFrame{}, frame)
b := make([]byte, frame.(*dataFrame).Length)
_, err = io.ReadFull(responseBuf, b)
require.NoError(t, err)
content = append(content, b...)
}
return hfs, content
}
// TestServerHandlerBodyNotRead covers handlers that leave the request body
// unread, including one that replaces r.Body with a different (empty) body.
func TestServerHandlerBodyNotRead(t *testing.T) {
	newReq := func(method string) *http.Request {
		return httptest.NewRequest(method, "https://www.example.com", bytes.NewBuffer([]byte("foobar")))
	}
	noop := func(http.ResponseWriter, *http.Request) {}

	t.Run("GET request with a body", func(t *testing.T) {
		testServerHandlerBodyNotRead(t, newReq(http.MethodGet), noop)
	})
	t.Run("POST body not read", func(t *testing.T) {
		testServerHandlerBodyNotRead(t, newReq(http.MethodPost), noop)
	})
	t.Run("POST request, with a replaced body", func(t *testing.T) {
		swapBody := func(_ http.ResponseWriter, r *http.Request) {
			r.Body = struct {
				io.Reader
				io.Closer
			}{}
		}
		testServerHandlerBodyNotRead(t, newReq(http.MethodPost), swapBody)
	})
}
// TestServerFirstFrameNotHeaders checks that a request stream starting with a
// DATA frame instead of a HEADERS frame makes the server close the connection
// with H3_FRAME_UNEXPECTED.
func TestServerFirstFrameNotHeaders(t *testing.T) {
	ctrl := gomock.NewController(t)

	// a 6-byte DATA frame carrying "foobar", where HEADERS was expected
	wire := append((&dataFrame{Length: 6}).Append(nil), "foobar"...)

	reqStream := mockquic.NewMockStream(ctrl)
	reqStream.EXPECT().Write(gomock.Any()).AnyTimes()
	reqStream.EXPECT().Context().Return(context.Background()).AnyTimes()
	reqStream.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewBuffer(wire).Read).AnyTimes()

	srv := &Server{TLSConfig: testdata.GetTLSConfig()}

	qconn := mockquic.NewMockEarlyConnection(ctrl)
	qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	qconn.EXPECT().LocalAddr().AnyTimes()
	qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())

	h3Conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
	srv.handleRequest(h3Conn, reqStream, nil, qpack.NewDecoder(nil))
}
// testServerHandlerBodyNotRead serves req with a handler that does not consume
// the request body. It expects the server to stop reading the stream with
// H3_NO_ERROR and asserts that the handler actually ran.
func testServerHandlerBodyNotRead(t *testing.T, req *http.Request, handler http.HandlerFunc) {
	ctrl := gomock.NewController(t)

	reqStream := mockquic.NewMockStream(ctrl)
	reqStream.EXPECT().Write(gomock.Any()).AnyTimes()
	reqStream.EXPECT().Context().Return(context.Background()).AnyTimes()
	reqStream.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
	reqStream.EXPECT().Close().MaxTimes(1)
	reqStream.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()

	handlerRan := false
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		handlerRan = true
		handler(w, r)
	}
	srv := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		Handler:   http.HandlerFunc(wrapped),
	}

	qconn := mockquic.NewMockEarlyConnection(ctrl)
	qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	qconn.EXPECT().LocalAddr().AnyTimes()
	qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
	qconn.EXPECT().Context().Return(context.Background()).AnyTimes()

	h3Conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
	srv.handleRequest(h3Conn, reqStream, nil, qpack.NewDecoder(nil))
	require.True(t, handlerRan)
}
// TestServerStreamResetByClient checks what happens when reading the request
// stream fails before any headers arrive. The stream must be cancelled in both
// directions with H3_REQUEST_INCOMPLETE, and the handler must never run.
func TestServerStreamResetByClient(t *testing.T) {
	ctrl := gomock.NewController(t)
	writeCanceled := make(chan struct{})

	reqStream := mockquic.NewMockStream(ctrl)
	reqStream.EXPECT().Read(gomock.Any()).Return(0, assert.AnError)
	reqStream.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
	reqStream.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) {
		close(writeCanceled)
	})

	handlerRan := false
	srv := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			handlerRan = true
		}),
	}

	qconn := mockquic.NewMockEarlyConnection(ctrl)
	qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	qconn.EXPECT().LocalAddr().AnyTimes()

	h3Conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
	srv.handleRequest(h3Conn, reqStream, nil, qpack.NewDecoder(nil))

	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case <-writeCanceled:
	case <-timer.C:
		t.Fatal("timeout")
	}
	require.False(t, handlerRan)
}
// TestServerPanickingHandler checks that a handler panic is logged, except for
// http.ErrAbortHandler, which is a deliberate abort and must stay silent.
func TestServerPanickingHandler(t *testing.T) {
	t.Run("panicking handler", func(t *testing.T) {
		out := testServerPanickingHandler(t, func(http.ResponseWriter, *http.Request) {
			panic("foobar")
		})
		for _, want := range []string{"http3: panic serving", "foobar"} {
			require.Contains(t, out, want)
		}
	})
	t.Run("http.ErrAbortHandler", func(t *testing.T) {
		out := testServerPanickingHandler(t, func(http.ResponseWriter, *http.Request) {
			panic(http.ErrAbortHandler)
		})
		for _, unwanted := range []string{"http3: panic serving", "http.ErrAbortHandler"} {
			require.NotContains(t, out, unwanted)
		}
	})
}
// testServerPanickingHandler serves a HEAD request with a handler that is
// expected to panic. It checks that the stream is cancelled in both directions
// with H3_INTERNAL_ERROR and returns everything the server logged.
func testServerPanickingHandler(t *testing.T, handler http.HandlerFunc) (logOutput string) {
	ctrl := gomock.NewController(t)
	headReq := httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)

	reqStream := mockquic.NewMockStream(ctrl)
	reqStream.EXPECT().Context().Return(context.Background()).AnyTimes()
	reqStream.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
	reqStream.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
	reqStream.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, headReq).Read).AnyTimes()

	logs := &bytes.Buffer{}
	srv := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		Handler:   handler,
		Logger:    slog.New(slog.NewTextHandler(logs, nil)),
	}

	qconn := mockquic.NewMockEarlyConnection(ctrl)
	qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	qconn.EXPECT().LocalAddr().AnyTimes()
	qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
	qconn.EXPECT().Context().Return(context.Background()).AnyTimes()

	h3Conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
	srv.handleRequest(h3Conn, reqStream, nil, qpack.NewDecoder(nil))
	logOutput = logs.String()
	return logOutput
}
// TestServerRequestHeaderTooLarge checks that requests whose headers exceed the
// configured (or default) header size limit are rejected.
func TestServerRequestHeaderTooLarge(t *testing.T) {
	t.Run("default value", func(t *testing.T) {
		// The host is 2*DefaultMaxHeaderBytes long. QPACK compresses the request,
		// but it still ends up larger than DefaultMaxHeaderBytes.
		longHost := string(bytes.Repeat([]byte{'a'}, 2*http.DefaultMaxHeaderBytes))
		req := httptest.NewRequest(http.MethodGet, "https://"+longHost, nil)
		testServerRequestHeaderTooLarge(t, req, 0) // Server.MaxHeaderBytes left unset
	})
	t.Run("custom value", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)
		testServerRequestHeaderTooLarge(t, req, 20)
	})
}
// testServerRequestHeaderTooLarge runs handleConn with a single request stream
// carrying req. The encoded headers are expected to exceed the limit
// (maxHeaderBytes, or the server default when it is 0). The request stream must
// then be cancelled in both directions with H3_FRAME_ERROR, and the handler
// must never run.
func testServerRequestHeaderTooLarge(t *testing.T, req *http.Request, maxHeaderBytes int) {
	var called bool
	s := &Server{
		TLSConfig:      testdata.GetTLSConfig(),
		MaxHeaderBytes: maxHeaderBytes,
		Handler:        http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }),
	}
	s.init()

	// one signal each for CancelRead and CancelWrite
	done := make(chan struct{}, 2)
	mockCtrl := gomock.NewController(t)
	str := mockquic.NewMockStream(mockCtrl)
	str.EXPECT().Context().Return(context.Background()).AnyTimes()
	str.EXPECT().StreamID().AnyTimes()
	str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
	str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
	str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()

	// Closed on return to release the mocked AcceptUniStream call. Without it,
	// the server's uni-stream accept goroutine stays blocked forever and leaks.
	testDone := make(chan struct{})
	defer close(testDone)

	conn := mockquic.NewMockEarlyConnection(mockCtrl)
	controlStr := mockquic.NewMockStream(mockCtrl)
	controlStr.EXPECT().Write(gomock.Any())
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, assert.AnError
	}).MaxTimes(1)
	// first accept the request stream, then end the accept loop with an error
	conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, assert.AnError)
	conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	conn.EXPECT().LocalAddr().AnyTimes()
	conn.EXPECT().Context().Return(context.Background()).AnyTimes()
	s.handleConn(conn)

	for range 2 {
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	require.False(t, called)
}
// TestServerRequestContext checks the context handed to handlers. It must
// inherit values from the QUIC connection's context and carry the Server plus
// the local and remote addresses. It must also be canceled when the request
// stream's context is canceled.
func TestServerRequestContext(t *testing.T) {
	responseBuf := &bytes.Buffer{}
	mockCtrl := gomock.NewController(t)
	str := mockquic.NewMockStream(mockCtrl)
	strCtx, strCtxCancel := context.WithCancel(context.Background())
	// Also canceled explicitly below. The defer only matters if the test fails
	// before reaching that point.
	defer strCtxCancel()
	str.EXPECT().StreamID().AnyTimes()
	str.EXPECT().Context().Return(strCtx).AnyTimes()
	str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
	str.EXPECT().CancelRead(gomock.Any())
	str.EXPECT().Close()
	str.EXPECT().Read(gomock.Any()).DoAndReturn(
		encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
	).AnyTimes()

	ctxChan := make(chan context.Context, 1)
	s := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctxChan <- r.Context()
		}),
	}
	s.init()

	testDone := make(chan struct{})
	defer close(testDone)

	controlStr := mockquic.NewMockStream(mockCtrl)
	controlStr.EXPECT().Write(gomock.Any()).AnyTimes()
	conn := mockquic.NewMockEarlyConnection(mockCtrl)
	conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}).AnyTimes()
	conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}).AnyTimes()
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
	// A dedicated key type: built-in string keys can collide across packages
	// (staticcheck SA1029).
	type connCtxKeyType string
	connCtxKey := connCtxKeyType("connection context")
	connCtx := context.WithValue(context.Background(), connCtxKey, "connection context value")
	conn.EXPECT().Context().Return(connCtx).AnyTimes()
	// The control stream gets its own mock. Otherwise the SETTINGS frame would
	// be written to the request stream's response buffer.
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, assert.AnError
	}).MaxTimes(1)
	conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, assert.AnError
	}).MaxTimes(1)
	go s.handleConn(conn)

	var requestContext context.Context
	select {
	case requestContext = <-ctxChan:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Equal(t, "connection context value", requestContext.Value(connCtxKey))
	require.Equal(t, s, requestContext.Value(ServerContextKey))
	require.Equal(t, &net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}, requestContext.Value(http.LocalAddrContextKey))
	require.Equal(t, &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}, requestContext.Value(RemoteAddrContextKey))

	// the request context stays alive as long as the stream context does...
	select {
	case <-requestContext.Done():
		t.Fatal("request context was canceled")
	case <-time.After(scaleDuration(10 * time.Millisecond)):
	}

	// ...and is canceled together with it
	strCtxCancel()
	select {
	case <-requestContext.Done():
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Equal(t, context.Canceled, requestContext.Err())
}
// TestServerHTTPStreamHijacking checks that after a handler takes over the
// stream via HTTPStreamer, the server writes only the response headers. The
// rest of the stream belongs to the handler (here: raw "foobar" bytes without
// DATA framing). No Close/CancelRead/CancelWrite calls are expected on the
// mocked stream.
func TestServerHTTPStreamHijacking(t *testing.T) {
	responseBuf := &bytes.Buffer{}
	mockCtrl := gomock.NewController(t)
	str := mockquic.NewMockStream(mockCtrl)
	str.EXPECT().Context().Return(context.Background()).AnyTimes()
	str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
	str.EXPECT().Read(gomock.Any()).DoAndReturn(
		encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
	).AnyTimes()

	s := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.(HTTPStreamer).HTTPStream()
			// write directly to the underlying QUIC stream, bypassing HTTP/3 framing
			str.Write([]byte("foobar"))
		}),
	}

	qconn := mockquic.NewMockEarlyConnection(mockCtrl)
	qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	qconn.EXPECT().LocalAddr().AnyTimes()
	qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
	qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
	conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
	s.handleRequest(conn, str, nil, qpack.NewDecoder(nil))

	hfs := decodeHeader(t, responseBuf)
	require.Equal(t, []string{"200"}, hfs[":status"])
	require.Equal(t, []byte("foobar"), responseBuf.Bytes())
}
func TestServerStreamHijacking(t *testing.T) {
for _, bidirectional := range []bool{true, false} {
name := "bidirectional"
if !bidirectional {
name = "unidirectional"
}
t.Run(name, func(t *testing.T) {
t.Run("hijack", func(t *testing.T) {
testServerHijackBidirectionalStream(t, bidirectional, true, nil)
})
t.Run("don't hijack", func(t *testing.T) {
testServerHijackBidirectionalStream(t, bidirectional, false, nil)
})
t.Run("hijacker error", func(t *testing.T) {
testServerHijackBidirectionalStream(t, bidirectional, false, assert.AnError)
})
})
}
}
// testServerHijackBidirectionalStream feeds the server one stream whose first
// varint is 0x41 and checks that it is offered to StreamHijacker (bidirectional)
// or UniStreamHijacker (unidirectional), together with the connection tracing ID.
//
// doHijack and hijackErr are the hijacker's return values. Unless the stream is
// hijacked without error, the server is expected to cancel it:
//   - bidirectional streams with H3_REQUEST_INCOMPLETE in both directions;
//   - unidirectional streams for reading with H3_STREAM_CREATION_ERROR.
//
// In no case may the connection be closed.
func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHijack bool, hijackErr error) {
	id := quic.ConnectionTracingID(1337)
	type hijackCall struct {
		ft            FrameType // for bidirectional streams
		st            StreamType // for unidirectional streams
		connTracingID quic.ConnectionTracingID
		e             error
	}
	hijackChan := make(chan hijackCall, 1)
	// Closed by whichever hijacker runs (only one is expected to). This also
	// releases the blocked accept calls below.
	testDone := make(chan struct{})
	s := &Server{
		TLSConfig: testdata.GetTLSConfig(),
		StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
			defer close(testDone)
			hijackChan <- hijackCall{ft: ft, connTracingID: connTracingID, e: e}
			return doHijack, hijackErr
		},
		UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
			defer close(testDone)
			hijackChan <- hijackCall{st: st, connTracingID: connTracingID, e: err}
			return doHijack
		},
	}
	s.init()

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	mockCtrl := gomock.NewController(t)
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	unknownStr.EXPECT().StreamID().AnyTimes()
	if !doHijack || hijackErr != nil {
		if bidirectional {
			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
			unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
		} else {
			unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
		}
	}

	conn := mockquic.NewMockEarlyConnection(mockCtrl)
	if bidirectional {
		conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
	} else {
		conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
	}
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, assert.AnError
	})
	controlStr := mockquic.NewMockStream(mockCtrl)
	controlStr.EXPECT().Write(gomock.Any())
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
	conn.EXPECT().LocalAddr().AnyTimes()
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, assert.AnError
	})
	ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
	conn.EXPECT().Context().Return(ctx).AnyTimes()
	s.handleConn(conn)

	select {
	case hijackCall := <-hijackChan:
		if bidirectional {
			assert.Zero(t, hijackCall.st)
			assert.Equal(t, FrameType(0x41), hijackCall.ft)
		} else {
			assert.Equal(t, StreamType(0x41), hijackCall.st)
			assert.Zero(t, hijackCall.ft)
		}
		assert.Equal(t, id, hijackCall.connTracingID)
		assert.NoError(t, hijackCall.e)
	case <-time.After(time.Second):
		t.Fatal("hijack call not received")
	}
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
}
// getAltSvc returns the first Alt-Svc value that s.SetQUICHeaders adds to a
// fresh header map, and whether that header was set at all.
func getAltSvc(s *Server) (string, bool) {
	hdr := make(http.Header)
	s.SetQUICHeaders(hdr)
	values, found := hdr["Alt-Svc"]
	if !found {
		return "", false
	}
	return values[0], true
}
func TestServerAltSvcFromListenersAndConns(t *testing.T) {
t.Run("default", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{})
})
t.Run("v1", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1})
})
t.Run("v1 and v2", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1, quic.Version2})
})
}
// testServerAltSvcFromListenersAndConns checks that the Alt-Svc header lists
// one h3 entry per listener or packet conn currently being served (via
// ServeListener and Serve). Each entry must disappear once its listener or conn
// is closed. versions is passed to the server's QUICConfig. The expected
// Alt-Svc strings below are the same for every versions value.
func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) {
ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
require.NoError(t, err)
port1 := ln1.Addr().(*net.UDPAddr).Port
s := &Server{
Addr: ":1337", // will be ignored since we're using listeners
TLSConfig: testdata.GetTLSConfig(),
QUICConfig: &quic.Config{Versions: versions},
}
done1 := make(chan struct{})
go func() {
defer close(done1)
s.ServeListener(ln1)
}()
// give ServeListener time to register the listener (timing-based, not synchronized)
time.Sleep(scaleDuration(10 * time.Millisecond))
altSvc, ok := getAltSvc(s)
require.True(t, ok)
require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port1), altSvc)
// add a second entry point, this time a plain packet conn passed to Serve
udpConn := newUDPConnLocalhost(t)
port2 := udpConn.LocalAddr().(*net.UDPAddr).Port
done2 := make(chan struct{})
go func() {
defer close(done2)
s.Serve(udpConn)
}()
time.Sleep(scaleDuration(10 * time.Millisecond))
altSvc, ok = getAltSvc(s)
require.True(t, ok)
require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000,h3=":%d"; ma=2592000`, port1, port2), altSvc)
// Close the first listener.
// This should remove the associated Alt-Svc entry.
require.NoError(t, ln1.Close())
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("timeout")
}
altSvc, ok = getAltSvc(s)
require.True(t, ok)
require.Equal(t, fmt.Sprintf(`h3=":%d"; ma=2592000`, port2), altSvc)
// Close the second listener.
// This should remove the Alt-Svc entry altogether.
require.NoError(t, udpConn.Close())
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("timeout")
}
_, ok = getAltSvc(s)
require.False(t, ok)
}
// TestServerAltSvcFromPort checks three things about Alt-Svc with a configured
// Server.Port (1337):
//   - it advertises exactly that port while a listener is being served;
//   - nothing is advertised before serving starts;
//   - nothing is advertised after the listener is closed.
func TestServerAltSvcFromPort(t *testing.T) {
	srv := &Server{Port: 1337}
	_, announced := getAltSvc(srv)
	require.False(t, announced)

	ln, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
	require.NoError(t, err)
	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		srv.ServeListener(ln)
	}()
	// give ServeListener a moment to register the listener
	time.Sleep(scaleDuration(10 * time.Millisecond))

	altSvc, announced := getAltSvc(srv)
	require.True(t, announced)
	require.Equal(t, `h3=":1337"; ma=2592000`, altSvc)

	require.NoError(t, ln.Close())
	select {
	case <-serveDone:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	_, announced = getAltSvc(srv)
	require.False(t, announced)
}
// unixSocketListener wraps a QUIC listener but reports a Unix socket address.
// The server cannot extract a port to announce from such an address.
type unixSocketListener struct{ *quic.EarlyListener }

// Addr overrides the embedded listener's address with a fixed Unix socket path.
func (l *unixSocketListener) Addr() net.Addr {
	addr := net.UnixAddr{Name: "/tmp/quic.sock", Net: "unix"}
	return &addr
}
// TestServerAltSvcFromUnixSocket checks Alt-Svc for a listener without a usable
// port. Nothing is announced unless Server.Addr supplies a port.
func TestServerAltSvcFromUnixSocket(t *testing.T) {
	t.Run("with Server.Addr not set", func(t *testing.T) {
		_, announced := testServerAltSvcFromUnixSocket(t, "")
		require.False(t, announced)
	})
	t.Run("with Server.Addr set", func(t *testing.T) {
		altSvc, announced := testServerAltSvcFromUnixSocket(t, ":1337")
		require.True(t, announced)
		require.Equal(t, `h3=":1337"; ma=2592000`, altSvc)
	})
}
// testServerAltSvcFromUnixSocket serves a listener that reports a Unix socket
// address, using addr as Server.Addr. It returns the Alt-Svc value (if any)
// observed while serving. It also asserts that the server logged that it could
// not extract a port from the listener.
func testServerAltSvcFromUnixSocket(t *testing.T, addr string) (altSvc string, ok bool) {
	quicLn, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
	require.NoError(t, err)

	logs := &bytes.Buffer{}
	srv := &Server{
		Addr:   addr,
		Logger: slog.New(slog.NewTextHandler(logs, nil)),
	}

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		srv.ServeListener(&unixSocketListener{EarlyListener: quicLn})
	}()
	time.Sleep(scaleDuration(10 * time.Millisecond))
	altSvc, ok = getAltSvc(srv)

	require.NoError(t, quicLn.Close())
	select {
	case <-serveDone:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Contains(t, logs.String(), "Unable to extract port from listener, will not be announced using SetQUICHeaders")
	return altSvc, ok
}
func TestServerListenAndServeErrors(t *testing.T) {
require.EqualError(t, (&Server{}).ListenAndServe(), "use of http3.Server without TLSConfig")
s := &Server{
Addr: ":123456",
TLSConfig: testdata.GetTLSConfig(),
}
require.ErrorContains(t, s.ListenAndServe(), "invalid port")
}
// TestServerClosing checks that Close can be called repeatedly. It also checks
// that every Serve/Listen entry point of a closed server returns
// http.ErrServerClosed.
func TestServerClosing(t *testing.T) {
	srv := &Server{TLSConfig: testdata.GetTLSConfig()}
	for range 2 { // duplicate calls are ok
		require.NoError(t, srv.Close())
	}

	entryPoints := []func() error{
		srv.ListenAndServe,
		func() error { return srv.ListenAndServeTLS(testdata.GetCertificatePaths()) },
		func() error { return srv.Serve(nil) },
		func() error { return srv.ServeListener(nil) },
		func() error { return srv.ServeQUICConn(nil) },
	}
	for _, serve := range entryPoints {
		require.ErrorIs(t, serve(), http.ErrServerClosed)
	}
}
// TestServerConcurrentServeAndClose checks that Close makes a concurrently
// running (or just starting) Serve call return.
func TestServerConcurrentServeAndClose(t *testing.T) {
	addr, err := net.ResolveUDPAddr("udp", "localhost:0")
	require.NoError(t, err)
	c, err := net.ListenUDP("udp", addr)
	require.NoError(t, err)
	// Nothing here guarantees that Server.Close releases a conn handed to Serve,
	// so release the socket explicitly instead of leaking it.
	defer c.Close()

	done := make(chan struct{})
	s := &Server{TLSConfig: testdata.GetTLSConfig()}
	go func() {
		defer close(done)
		s.Serve(c)
	}()
	// let the Serve goroutine get scheduled, so Close races with it
	runtime.Gosched()
	s.Close()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerImmediateGracefulShutdown checks that Shutdown on a server that
// never served anything returns promptly and without error.
func TestServerImmediateGracefulShutdown(t *testing.T) {
	srv := &Server{TLSConfig: testdata.GetTLSConfig()}
	result := make(chan error, 1)
	go func() {
		result <- srv.Shutdown(context.Background())
	}()

	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("timeout")
	case err := <-result:
		require.NoError(t, err)
	}
}
// TestServerGracefulShutdown drives Server.Shutdown on a connection that has
// one in-flight request stream (ID 1337). It checks the following:
//   - a GOAWAY frame is sent on the control stream;
//   - Shutdown keeps waiting while that stream is still open;
//   - streams accepted after the GOAWAY are rejected with H3_REQUEST_REJECTED;
//   - canceling Shutdown's context closes the connection with H3_NO_ERROR
//     and makes Shutdown return context.Canceled.
func TestServerGracefulShutdown(t *testing.T) {
s := &Server{TLSConfig: testdata.GetTLSConfig()}
s.init()
mockCtrl := gomock.NewController(t)
controlStr := mockquic.NewMockStream(mockCtrl)
// every write to the control stream (SETTINGS, later GOAWAY) is forwarded here
controlStrChan := make(chan []byte, 1)
controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
controlStrChan <- b
return len(b), nil
}).AnyTimes()
// streams pushed here are returned by the mocked AcceptStream
streamChan := make(chan quic.Stream, 1)
testDone := make(chan struct{})
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, assert.AnError
}).MaxTimes(1)
conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case str, ok := <-streamChan:
if !ok {
return nil, assert.AnError
}
return str, nil
}
}).AnyTimes()
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().Context().Return(context.Background()).AnyTimes()
// The in-flight request: its Read signals that the stream was accepted, then
// blocks until testDone is closed, keeping the request open during Shutdown.
firstStream := mockquic.NewMockStream(mockCtrl)
firstStreamAccepted := make(chan struct{}, 1)
firstStream.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
firstStreamAccepted <- struct{}{}
<-testDone
return 0, assert.AnError
})
firstStream.EXPECT().StreamID().Return(quic.StreamID(1337)).AnyTimes()
firstStream.EXPECT().Context().Return(context.Background()).AnyTimes()
streamChan <- firstStream
go s.ServeQUICConn(conn)
var r bytes.Buffer
fp := &frameParser{r: &r}
// the first control stream write is the stream type followed by SETTINGS
select {
case b := <-controlStrChan:
_, l, err := quicvarint.Parse(b)
require.NoError(t, err)
r.Write(b[l:])
f, err := fp.ParseNext()
require.NoError(t, err)
require.IsType(t, &settingsFrame{}, f)
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case <-firstStreamAccepted:
case <-time.After(time.Second):
t.Fatal("timeout")
}
time.Sleep(scaleDuration(10 * time.Millisecond))
ctx, cancel := context.WithCancel(context.Background())
errChan := make(chan error)
go func() {
errChan <- s.Shutdown(ctx)
}()
// Shutdown sends a GOAWAY whose stream ID is the one after the last accepted
// stream. QUIC stream IDs of the same type are spaced by 4 (RFC 9000).
select {
case b := <-controlStrChan:
r.Write(b)
f, err := fp.ParseNext()
require.NoError(t, err)
require.Equal(t, &goAwayFrame{StreamID: 1337 + 4}, f)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// the first stream is still open, so Shutdown must keep waiting
select {
case <-errChan:
t.Fatal("didn't expect Shutdown to return")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
// all further streams are rejected, in both directions
for range 3 {
resetChan := make(chan struct{}, 2)
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().StreamID().AnyTimes()
str.EXPECT().Context().Return(context.Background()).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
resetChan <- struct{}{}
})
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
resetChan <- struct{}{}
})
streamChan <- str
for range 2 {
select {
case <-resetChan:
case <-time.After(time.Second):
t.Fatal("expected stream reset")
}
}
}
// Cancel the context passed to Shutdown: the connection is then expected to
// be closed with H3_NO_ERROR.
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any())
cancel()
// Once testDone is closed, the first stream's blocked Read fails, and the
// stream is expected to be cancelled in both directions (any error code).
firstStream.EXPECT().CancelRead(gomock.Any())
firstStream.EXPECT().CancelWrite(gomock.Any())
close(testDone)
select {
case err := <-errChan:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}