// server_test.go — unit tests for the QUIC server (baseServer) implementation.
package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"slices"
"testing"
"time"
"github.com/quic-go/quic-go/internal/handshake"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
// testServer wraps a baseServer for use in tests.
type testServer struct{ *baseServer }

// serverOpts bundles the configurable knobs for newTestServer.
type serverOpts struct {
	tracer            *logging.Tracer
	config            *Config
	tokenGeneratorKey TokenGeneratorKey
	maxTokenAge       time.Duration
	// useRetry makes the server require source address validation,
	// i.e. send a Retry packet before accepting a connection.
	useRetry                  bool
	disableVersionNegotiation bool
	acceptEarly               bool
	// newConn, if set, replaces the server's connection constructor,
	// allowing tests to intercept and inspect connection creation.
	newConn func(
		context.Context,
		context.CancelCauseFunc,
		sendConn,
		connRunner,
		protocol.ConnectionID,  // original dest connection ID
		*protocol.ConnectionID, // retry src connection ID
		protocol.ConnectionID,  // client dest connection ID
		protocol.ConnectionID,  // destination connection ID
		protocol.ConnectionID,  // source connection ID
		ConnectionIDGenerator,
		*statelessResetter,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool, /* client address validated by an address validation token */
		time.Duration,
		*logging.ConnectionTracer,
		utils.Logger,
		protocol.Version,
	) *wrappedConn
}
// newTestServer constructs a baseServer listening on a fresh localhost UDP
// socket, configured according to serverOpts. Source addresses are only
// verified when serverOpts.useRetry is set. The server is closed
// automatically on test cleanup.
func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
	t.Helper()
	c, err := wrapConn(newUDPConnLocalhost(t))
	require.NoError(t, err)
	// requiring source address validation is what triggers Retry packets
	verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry }
	config := populateConfig(serverOpts.config)
	tr := &Transport{Conn: newUDPConnLocalhost(t)}
	tr.init(true)
	s := newServer(
		c,
		(*packetHandlerMap)(tr),
		&protocol.DefaultConnectionIDGenerator{},
		&statelessResetter{},
		func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil },
		&tls.Config{},
		config,
		serverOpts.tracer,
		func() {}, // no-op callback — NOTE(review): confirm its role against newServer's signature
		serverOpts.tokenGeneratorKey,
		serverOpts.maxTokenAge,
		verifySourceAddress,
		serverOpts.disableVersionNegotiation,
		serverOpts.acceptEarly,
	)
	s.newConn = serverOpts.newConn
	t.Cleanup(func() { s.Close() })
	return &testServer{s}
}
// getLongHeaderPacketEncrypted serializes an Initial packet and encrypts it
// (payload sealing plus header protection) using the client's Initial AEAD,
// so that the server can actually decrypt it. Only Initial packets are
// supported; any other packet type fails the test.
func getLongHeaderPacketEncrypted(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
	t.Helper()
	hdr := extHdr.Header
	if hdr.Type != protocol.PacketTypeInitial {
		t.Fatal("can only encrypt Initial packets")
	}
	p := getLongHeaderPacket(t, remoteAddr, extHdr, data)
	sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
	n := len(p.data) - len(data) // length of the header
	// seal the payload in place; the AEAD tag adds 16 bytes
	p.data = slices.Grow(p.data, 16)
	_ = sealer.Seal(p.data[n:n], p.data[n:], extHdr.PacketNumber, p.data[:n])
	p.data = p.data[:len(p.data)+16]
	// apply header protection, sampling the first 16 bytes of the sealed payload
	sealer.EncryptHeader(p.data[n:n+16], &p.data[0], p.data[n-int(extHdr.PacketNumberLen):n])
	return p
}
// randConnID returns a random connection ID of length l.
func randConnID(l int) protocol.ConnectionID {
	data := make([]byte, l)
	_, _ = rand.Read(data) // crypto/rand.Read doesn't fail in practice
	return protocol.ParseConnectionID(data)
}
// getValidInitialPacket builds a minimum-size (1200 byte) Initial packet
// for QUIC v1, using the given source and destination connection IDs.
func getValidInitialPacket(t *testing.T, raddr net.Addr, srcConnID, destConnID protocol.ConnectionID) receivedPacket {
	t.Helper()
	extHdr := &wire.ExtendedHeader{
		Header: wire.Header{
			Type:             protocol.PacketTypeInitial,
			SrcConnectionID:  srcConnID,
			DestConnectionID: destConnID,
			Length:           protocol.MinInitialPacketSize,
			Version:          protocol.Version1,
		},
		PacketNumberLen: protocol.PacketNumberLen4,
	}
	payload := make([]byte, protocol.MinInitialPacketSize)
	return getLongHeaderPacket(t, raddr, extHdr, payload)
}
// sentPacketCallArgs records the arguments of a SentPacket tracer call.
type sentPacketCallArgs struct {
	hdr    *logging.Header
	frames []logging.Frame
}
// checkConnectionClose checks
// 1. the arguments of the SentPacket tracer call, and
// 2. reads and parses the packet sent by the server.
// It verifies that the server sent an Initial packet carrying exactly one
// CONNECTION_CLOSE frame with the expected error code and connection IDs.
func checkConnectionClose(
	t *testing.T,
	conn *net.UDPConn,
	argChan <-chan sentPacketCallArgs, // fed by the SentPacket mock expectation
	expectedSrcConnID protocol.ConnectionID,
	expectedDestConnID protocol.ConnectionID,
	expectedErrorCode qerr.TransportErrorCode,
) {
	t.Helper()
	// 1. check the tracer call arguments
	select {
	case args := <-argChan:
		require.Equal(t, protocol.PacketTypeInitial, args.hdr.Type)
		require.Equal(t, expectedSrcConnID, args.hdr.SrcConnectionID)
		require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID)
		require.Len(t, args.frames, 1)
		require.IsType(t, &wire.ConnectionCloseFrame{}, args.frames[0])
		ccf := args.frames[0].(*logging.ConnectionCloseFrame)
		require.EqualValues(t, expectedErrorCode, ccf.ErrorCode)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// 2. read and parse the packet actually sent on the wire
	conn.SetReadDeadline(time.Now().Add(time.Second))
	b := make([]byte, 1500)
	n, _, err := conn.ReadFromUDP(b)
	require.NoError(t, err)
	parsedHdr, _, _, err := wire.ParsePacket(b[:n])
	require.NoError(t, err)
	require.Equal(t, protocol.PacketTypeInitial, parsedHdr.Type)
	require.Equal(t, expectedSrcConnID, parsedHdr.SrcConnectionID)
	require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID)
}
// checkRetry verifies that the server sent a Retry packet with the expected
// destination connection ID and a non-empty token. Like checkConnectionClose,
// it checks both the SentPacket tracer call and the packet on the wire.
func checkRetry(t *testing.T,
	conn *net.UDPConn,
	argChan <-chan sentPacketCallArgs, // fed by the SentPacket mock expectation
	expectedDestConnID protocol.ConnectionID,
) {
	t.Helper()
	// 1. check the tracer call arguments
	select {
	case args := <-argChan:
		require.Equal(t, protocol.PacketTypeRetry, args.hdr.Type)
		require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID)
		require.NotNil(t, args.hdr.Token)
		require.Empty(t, args.frames)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// 2. read and parse the packet actually sent on the wire
	conn.SetReadDeadline(time.Now().Add(time.Second))
	b := make([]byte, 1500)
	n, _, err := conn.ReadFromUDP(b)
	require.NoError(t, err)
	parsedHdr, _, _, err := wire.ParsePacket(b[:n])
	require.NoError(t, err)
	require.Equal(t, protocol.PacketTypeRetry, parsedHdr.Type)
	require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID)
	require.NotNil(t, parsedHdr.Token)
}
// TestListen checks the input validation performed by Listen / ListenAddr.
func TestListen(t *testing.T) {
	// a TLS config is mandatory
	_, err := ListenAddr("localhost:0", nil, nil)
	require.ErrorContains(t, err, "quic: tls.Config not set")
	// unknown QUIC versions are rejected
	_, err = Listen(nil, &tls.Config{}, &Config{Versions: []protocol.Version{0x1234}})
	require.ErrorContains(t, err, "invalid QUIC version: 0x1234")
}
// TestListenAddr checks address handling of ListenAddr: malformed and
// unbindable addresses fail with the appropriate error type, and a valid
// localhost address succeeds.
func TestListenAddr(t *testing.T) {
	for _, tc := range []struct {
		addr    string
		errType any
	}{
		{addr: "127.0.0.1", errType: &net.AddrError{}},     // missing port
		{addr: "1.1.1.1:1111", errType: &net.OpError{}},    // can't bind to a foreign address
	} {
		_, err := ListenAddr(tc.addr, &tls.Config{}, &Config{})
		require.Error(t, err)
		require.IsType(t, tc.errType, err)
	}
	ln, err := ListenAddr("127.0.0.1:0", &tls.Config{}, &Config{})
	require.NoError(t, err)
	defer ln.Close()
}
// TestServerPacketDropping verifies that the server drops (and traces as
// dropped) packets that cannot possibly start a new connection.
func TestServerPacketDropping(t *testing.T) {
	// a 7-byte destination connection ID is too short for an Initial packet
	t.Run("destination connection ID too short", func(t *testing.T) {
		conn := newUDPConnLocalhost(t)
		testServerDroppedPacket(t,
			conn,
			getValidInitialPacket(t, conn.LocalAddr(), randConnID(5), randConnID(7)),
			logging.PacketTypeInitial,
			logging.PacketDropUnexpectedPacket,
		)
	})
	// Initial packets below the minimum size are dropped
	t.Run("Initial packet too small", func(t *testing.T) {
		conn := newUDPConnLocalhost(t)
		p := getLongHeaderPacket(t,
			conn.LocalAddr(),
			&wire.ExtendedHeader{
				Header: wire.Header{
					Type:             protocol.PacketTypeInitial,
					DestConnectionID: randConnID(8),
					Version:          protocol.Version1,
				},
				PacketNumberLen: 2,
			},
			make([]byte, protocol.MinInitialPacketSize-100),
		)
		// sanity check: the packet is just below the minimum size
		require.Greater(t, len(p.data), protocol.MinInitialPacketSize-100)
		require.Less(t, len(p.data), protocol.MinInitialPacketSize)
		testServerDroppedPacket(t,
			conn,
			p,
			logging.PacketTypeInitial,
			logging.PacketDropUnexpectedPacket,
		)
	})
	// we should not send a Version Negotiation packet if the packet is smaller than 1200 bytes
	t.Run("packet of unknown version, too small", func(t *testing.T) {
		conn := newUDPConnLocalhost(t)
		p := getLongHeaderPacket(t,
			conn.LocalAddr(),
			&wire.ExtendedHeader{
				Header: wire.Header{
					Type:             protocol.PacketTypeInitial,
					DestConnectionID: randConnID(8),
					Version:          0x42, // an unknown version
				},
				PacketNumberLen: 2,
			},
			make([]byte, protocol.MinUnknownVersionPacketSize-100),
		)
		// sanity check: the packet is just below the minimum size
		require.Greater(t, len(p.data), protocol.MinUnknownVersionPacketSize-100)
		require.Less(t, len(p.data), protocol.MinUnknownVersionPacketSize)
		testServerDroppedPacket(t,
			conn,
			p,
			logging.PacketTypeNotDetermined,
			logging.PacketDropUnexpectedPacket,
		)
	})
	// only Initial packets can start a new connection
	t.Run("not an Initial packet", func(t *testing.T) {
		conn := newUDPConnLocalhost(t)
		testServerDroppedPacket(t,
			conn,
			getLongHeaderPacket(t,
				conn.LocalAddr(),
				&wire.ExtendedHeader{
					Header: wire.Header{
						Type:             protocol.PacketTypeHandshake,
						DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
						Version:          protocol.Version1,
					},
					PacketNumberLen: 2,
				},
				nil,
			),
			logging.PacketTypeHandshake,
			logging.PacketDropUnexpectedPacket,
		)
	})
	// as a server, we should never receive a Version Negotiation packet
	t.Run("Version Negotiation packet", func(t *testing.T) {
		conn := newUDPConnLocalhost(t)
		data := wire.ComposeVersionNegotiation(
			protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
			protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
			[]protocol.Version{1, 2, 3},
		)
		testServerDroppedPacket(t,
			conn,
			receivedPacket{
				remoteAddr: conn.LocalAddr(),
				data:       data,
				buffer:     getPacketBuffer(),
			},
			logging.PacketTypeVersionNegotiation,
			logging.PacketDropUnexpectedPacket,
		)
	})
}
// testServerDroppedPacket feeds p to a fresh server and verifies that
// 1. the tracer records the packet as dropped, with the expected packet type
//    and drop reason, and
// 2. the server doesn't send any response back to the client.
func testServerDroppedPacket(t *testing.T,
	conn *net.UDPConn,
	p receivedPacket,
	expectedPacketType logging.PacketType,
	expectedDropReason logging.PacketDropReason,
) {
	// any packet arriving at conn would mean the server (incorrectly) responded
	readChan := make(chan struct{})
	go func() {
		defer close(readChan)
		conn.ReadFrom(make([]byte, 1000))
	}()
	mockCtrl := gomock.NewController(t)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	server := newTestServer(t, &serverOpts{tracer: tracer})
	mockTracer.EXPECT().DroppedPacket(p.remoteAddr, expectedPacketType, p.Size(), expectedDropReason)
	server.handlePacket(p)
	// give the server a brief window to (incorrectly) send a response
	select {
	case <-readChan:
		t.Fatal("didn't expect to receive a packet")
	case <-time.After(scaleDuration(5 * time.Millisecond)):
	}
}
// TestServerVersionNegotiation runs the version negotiation test with the
// mechanism both enabled and disabled.
func TestServerVersionNegotiation(t *testing.T) {
	for _, enabled := range []bool{true, false} {
		name := "disabled"
		if enabled {
			name = "enabled"
		}
		t.Run(name, func(t *testing.T) { testServerVersionNegotiation(t, enabled) })
	}
}
// testServerVersionNegotiation sends a long header packet with an unsupported
// version (0x42) and checks that the server either replies with a Version
// Negotiation packet (enabled) or drops the packet (disabled).
func testServerVersionNegotiation(t *testing.T, enabled bool) {
	mockCtrl := gomock.NewController(t)
	conn := newUDPConnLocalhost(t)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	server := newTestServer(t, &serverOpts{
		tracer:                    tracer,
		disableVersionNegotiation: !enabled,
	})
	srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
	destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
	packet := getLongHeaderPacket(t, conn.LocalAddr(),
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketTypeHandshake,
				SrcConnectionID:  srcConnID,
				DestConnectionID: destConnID,
				Version:          0x42, // unsupported version
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, protocol.MinUnknownVersionPacketSize),
	)
	switch enabled {
	case true:
		mockTracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any())
	case false:
		mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeNotDetermined, packet.Size(), logging.PacketDropUnexpectedVersion)
	}
	// read whatever the server sends back on the wire
	written := make(chan []byte, 1)
	go func() {
		b := make([]byte, 1500)
		n, _, _ := conn.ReadFrom(b)
		written <- b[:n]
	}()
	server.handlePacket(packet)
	switch enabled {
	case true:
		select {
		case b := <-written:
			require.True(t, wire.IsVersionNegotiationPacket(b))
			dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
			require.NoError(t, err)
			// connection IDs are swapped: the reply's destination is our source
			require.Equal(t, protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), dest)
			require.Equal(t, protocol.ArbitraryLenConnectionID(destConnID.Bytes()), src)
			// the unsupported version must not be offered
			require.NotContains(t, versions, protocol.Version(0x42))
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	case false:
		select {
		case <-written:
			t.Fatal("expected no version negotiation packet")
		case <-time.After(scaleDuration(10 * time.Millisecond)):
		}
	}
}
// TestServerRetry verifies that a server requiring address validation
// responds to a client Initial with a Retry packet addressed to the
// client's source connection ID.
func TestServerRetry(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	server := newTestServer(t, &serverOpts{tracer: tracer, useRetry: true})
	conn := newUDPConnLocalhost(t)
	packet := getLongHeaderPacket(t, conn.LocalAddr(),
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketTypeInitial,
				SrcConnectionID:  protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}),
				DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, protocol.MinUnknownVersionPacketSize),
	)
	argsChan := make(chan sentPacketCallArgs, 1)
	mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(
		func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
			argsChan <- sentPacketCallArgs{hdr: hdr}
		},
	)
	server.handlePacket(packet)
	// the Retry is addressed to the client's source connection ID
	checkRetry(t, conn, argsChan, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))
}
// TestServerTokenValidation exercises the server's handling of invalid,
// expired or foreign tokens, both for Retry tokens (which trigger an
// INVALID_TOKEN connection close) and non-Retry (NEW_TOKEN) tokens
// (which trigger a new Retry).
func TestServerTokenValidation(t *testing.T) {
	var tokenGeneratorKey handshake.TokenProtectorKey
	rand.Read(tokenGeneratorKey[:])
	tg := handshake.NewTokenGenerator(tokenGeneratorKey)
	// a Retry token issued for a different client address is invalid
	t.Run("retry token with invalid address", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		token, err := tg.NewRetryToken(
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, // not the client's address
			protocol.ConnectionID{},
			protocol.ConnectionID{},
		)
		require.NoError(t, err)
		tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
		server := newTestServer(t, &serverOpts{
			useRetry:          true,
			tracer:            tracer,
			tokenGeneratorKey: tokenGeneratorKey,
		})
		testServerTokenValidation(t, server, mockTracer, newUDPConnLocalhost(t), token, false, true, false)
	})
	t.Run("expired retry token", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		conn := newUDPConnLocalhost(t)
		tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
		server := newTestServer(t, &serverOpts{
			useRetry:          true,
			tracer:            tracer,
			config:            &Config{HandshakeIdleTimeout: time.Millisecond / 2},
			tokenGeneratorKey: tokenGeneratorKey,
		})
		token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{})
		require.NoError(t, err)
		// the maximum retry token age is equivalent to the handshake timeout
		time.Sleep(time.Millisecond) // make sure the token is expired
		testServerTokenValidation(t, server, mockTracer, conn, token, false, true, false)
	})
	// if the packet is corrupted, it will just be dropped (no INVALID_TOKEN nor Retry is sent)
	t.Run("corrupted packet", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
		server := newTestServer(t, &serverOpts{
			useRetry:          true,
			tracer:            tracer,
			config:            &Config{HandshakeIdleTimeout: time.Millisecond / 2},
			tokenGeneratorKey: tokenGeneratorKey,
		})
		conn := newUDPConnLocalhost(t)
		token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{})
		require.NoError(t, err)
		time.Sleep(time.Millisecond) // make sure the token is expired
		testServerTokenValidation(t, server, mockTracer, conn, token, true, false, true)
	})
	// a non-Retry token signed with a different key can't be decoded;
	// the server falls back to sending a Retry
	t.Run("invalid non-retry token", func(t *testing.T) {
		var tokenGeneratorKey2 handshake.TokenProtectorKey
		rand.Read(tokenGeneratorKey2[:])
		mockCtrl := gomock.NewController(t)
		tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
		server := newTestServer(t, &serverOpts{
			tokenGeneratorKey: tokenGeneratorKey2, // use a different key
			useRetry:          true,
			tracer:            tracer,
			maxTokenAge:       time.Millisecond,
		})
		conn := newUDPConnLocalhost(t)
		token, err := tg.NewToken(conn.LocalAddr(), 10*time.Millisecond)
		require.NoError(t, err)
		time.Sleep(3 * time.Millisecond) // make sure the token is expired
		testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true)
	})
	// an expired non-Retry token also falls back to a Retry
	t.Run("expired non-retry token", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
		server := newTestServer(t, &serverOpts{
			tokenGeneratorKey: tokenGeneratorKey,
			useRetry:          true,
			tracer:            tracer,
			maxTokenAge:       time.Millisecond,
		})
		conn := newUDPConnLocalhost(t)
		token, err := tg.NewToken(conn.LocalAddr(), 100*time.Millisecond)
		require.NoError(t, err)
		time.Sleep(3 * time.Millisecond) // make sure the token is expired
		testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true)
	})
}
// testServerTokenValidation sends an encrypted Initial packet carrying the
// given token and asserts the server's reaction:
//   - corruptedPacket: the packet is corrupted first and expected to be
//     dropped with PacketDropPayloadDecryptError (nothing is sent back)
//   - expectInvalidTokenConnectionClose: the server closes with INVALID_TOKEN
//   - expectRetry: the server sends a (new) Retry packet
func testServerTokenValidation(
	t *testing.T,
	server *testServer,
	mockTracer *mocklogging.MockTracer,
	conn *net.UDPConn,
	token []byte,
	corruptedPacket bool,
	expectInvalidTokenConnectionClose bool,
	expectRetry bool,
) {
	hdr := wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Length:           protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16,
		Version:          protocol.Version1,
	}
	packet := getLongHeaderPacketEncrypted(t,
		conn.LocalAddr(),
		&wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4},
		make([]byte, protocol.MinInitialPacketSize),
	)
	if corruptedPacket {
		packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
		done := make(chan struct{})
		mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(
			func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
				close(done)
			},
		)
		server.handlePacket(packet)
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
		// no response is expected for a corrupted packet
		return
	}
	argsChan := make(chan sentPacketCallArgs, 1)
	if expectInvalidTokenConnectionClose || expectRetry {
		mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(
			func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
				argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames}
			},
		)
	}
	server.handlePacket(packet)
	if expectInvalidTokenConnectionClose {
		checkConnectionClose(t, conn, argsChan, hdr.DestConnectionID, hdr.SrcConnectionID, qerr.InvalidToken)
	}
	if expectRetry {
		checkRetry(t, conn, argsChan, hdr.SrcConnectionID)
	}
}
// connConstructorArgs captures the interesting arguments passed to the
// server's connection constructor.
type connConstructorArgs struct {
	ctx              context.Context
	connRunner       connRunner
	config           *Config
	origDestConnID   protocol.ConnectionID
	retrySrcConnID   *protocol.ConnectionID
	clientDestConnID protocol.ConnectionID
	destConnID       protocol.ConnectionID
	srcConnID        protocol.ConnectionID
}

// connConstructorRecorder records every NewConn call and hands out one set of
// test hooks per constructed connection, in the order they were provided.
type connConstructorRecorder struct {
	ch    chan connConstructorArgs
	hooks []*connTestHooks
}

// newConnConstructorRecorder creates a recorder that expects up to len(hooks)
// connections to be constructed.
func newConnConstructorRecorder(hooks ...*connTestHooks) *connConstructorRecorder {
	return &connConstructorRecorder{
		ch:    make(chan connConstructorArgs, len(hooks)),
		hooks: hooks,
	}
}

// Args returns the channel on which recorded constructor arguments are delivered.
func (r *connConstructorRecorder) Args() <-chan connConstructorArgs { return r.ch }

// NewConn matches the server's newConn signature. It records the interesting
// arguments and returns a wrappedConn driven by the next set of test hooks.
func (r *connConstructorRecorder) NewConn(
	ctx context.Context,
	_ context.CancelCauseFunc,
	_ sendConn,
	connRunner connRunner,
	origDestConnID protocol.ConnectionID,
	retrySrcConnID *protocol.ConnectionID,
	clientDestConnID protocol.ConnectionID,
	destConnID protocol.ConnectionID,
	srcConnID protocol.ConnectionID,
	_ ConnectionIDGenerator,
	_ *statelessResetter,
	config *Config,
	_ *tls.Config,
	_ *handshake.TokenGenerator,
	_ bool,
	_ time.Duration,
	_ *logging.ConnectionTracer,
	_ utils.Logger,
	_ protocol.Version,
) *wrappedConn {
	r.ch <- connConstructorArgs{
		ctx:              ctx,
		connRunner:       connRunner,
		config:           config,
		origDestConnID:   origDestConnID,
		retrySrcConnID:   retrySrcConnID,
		clientDestConnID: clientDestConnID,
		destConnID:       destConnID,
		srcConnID:        srcConnID,
	}
	// hand out the hooks in FIFO order, one set per connection
	hooks := r.hooks[0]
	r.hooks = r.hooks[1:]
	return &wrappedConn{testHooks: hooks}
}
// TestServerCreateConnection runs the connection creation test with and
// without a preceding Retry.
func TestServerCreateConnection(t *testing.T) {
	for _, useRetry := range []bool{false, true} {
		name := "without retry"
		if useRetry {
			name = "with retry"
		}
		t.Run(name, func(t *testing.T) { testServerCreateConnection(t, useRetry) })
	}
}
// testServerCreateConnection verifies that the server creates a new
// connection for a valid Initial packet, passing the correct connection IDs
// to the connection constructor — both with and without a preceding Retry.
// (Fix: the destConnID assertion was duplicated in the useRetry branch;
// the duplicate is removed.)
func testServerCreateConnection(t *testing.T, useRetry bool) {
	tokenGeneratorKey := TokenGeneratorKey{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
	tg := handshake.NewTokenGenerator(tokenGeneratorKey)
	server := newTestServer(t, &serverOpts{
		useRetry:          useRetry,
		tokenGeneratorKey: tokenGeneratorKey,
	})
	// run, context and handshakeComplete are each expected to be called once
	done := make(chan struct{}, 3)
	handledPackets := make(chan receivedPacket, 1)
	recorder := newConnConstructorRecorder(&connTestHooks{
		run:               func() error { done <- struct{}{}; return nil },
		context:           func() context.Context { done <- struct{}{}; return context.Background() },
		handshakeComplete: func() <-chan struct{} { done <- struct{}{}; return make(chan struct{}) },
		handlePacket:      func(p receivedPacket) { handledPackets <- p },
	})
	server.newConn = recorder.NewConn
	conn := newUDPConnLocalhost(t)
	var token []byte
	if useRetry {
		// simulate a Retry that already happened: the token encodes the
		// original destination and the Retry source connection ID
		var err error
		token, err = tg.NewRetryToken(
			conn.LocalAddr(),
			protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
			protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
		)
		require.NoError(t, err)
	}
	hdr := wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Length:           protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16,
		Token:            token,
		Version:          protocol.Version1,
	}
	packet := getLongHeaderPacketEncrypted(t,
		conn.LocalAddr(),
		&wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4},
		make([]byte, protocol.MinInitialPacketSize),
	)
	server.handlePacket(packet)
	// the Initial packet is passed on to the newly created connection
	select {
	case p := <-handledPackets:
		require.Equal(t, packet, p)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	var args connConstructorArgs
	select {
	case args = <-recorder.Args():
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// the connection's destination connection ID is the client's source connection ID
	assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID)
	assert.NotEqual(t, args.origDestConnID, args.srcConnID)
	if useRetry {
		// the pre-Retry connection IDs are recovered from the token
		assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), args.origDestConnID)
		assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), *args.retrySrcConnID)
	} else {
		assert.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), args.origDestConnID)
		assert.Zero(t, args.retrySrcConnID)
	}
	for range 3 {
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
}
// TestServerClose verifies that closing the server
//  1. closes all connections that are still handshaking with CONNECTION_REFUSED, and
//  2. makes Accept return ErrServerClosed (which also matches net.ErrClosed).
//
// Fix: the ctx cancel func in the Accept loop was deferred, keeping all
// context timers alive until the test returns; it's now released per iteration.
func TestServerClose(t *testing.T) {
	var hooks []*connTestHooks
	const numConns = 3
	done := make(chan struct{}, numConns)
	for range numConns {
		hooks = append(hooks, &connTestHooks{
			closeWithTransportError: func(TransportErrorCode) { done <- struct{}{} },
		})
	}
	recorder := newConnConstructorRecorder(hooks...)
	server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
	// establish numConns handshaking connections
	for range numConns {
		b := make([]byte, 10)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		server.handlePacket(getValidInitialPacket(t,
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
			randConnID(6),
			connID,
		))
		select {
		case <-recorder.Args():
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	server.Close()
	// closing closes all handshaking connections with CONNECTION_REFUSED
	for range numConns {
		select {
		case <-done:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	// Accept returns ErrServerClosed after closing
	for range 5 {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		_, err := server.Accept(ctx)
		cancel() // release the timer right away instead of deferring in a loop
		require.ErrorIs(t, err, ErrServerClosed)
		require.ErrorIs(t, err, net.ErrClosed)
	}
}
// TestServerGetConfigForClientAccept verifies that the config returned by
// GetConfigForClient is the one passed to the connection constructor.
func TestServerGetConfigForClientAccept(t *testing.T) {
	recorder := newConnConstructorRecorder(&connTestHooks{})
	server := newTestServer(t, &serverOpts{
		config: &Config{
			GetConfigForClient: func(*ClientInfo) (*Config, error) {
				// return a config with a recognizable value
				return &Config{MaxIncomingStreams: 1234}, nil
			},
		},
		newConn: recorder.NewConn,
	})
	conn := newUDPConnLocalhost(t)
	packet := getValidInitialPacket(t,
		conn.LocalAddr(),
		protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
	)
	server.handlePacket(packet)
	var args connConstructorArgs
	select {
	case args = <-recorder.Args():
		// the per-client config was used
		require.EqualValues(t, 1234, args.config.MaxIncomingStreams)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID)
	assert.NotEqual(t, args.origDestConnID, args.srcConnID)
}
// TestServerGetConfigForClientReject verifies that an error returned from
// GetConfigForClient makes the server reject the connection with a
// CONNECTION_REFUSED connection close.
func TestServerGetConfigForClientReject(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	server := newTestServer(t, &serverOpts{
		tracer: tracer,
		config: &Config{
			GetConfigForClient: func(*ClientInfo) (*Config, error) {
				return nil, errors.New("rejected")
			},
		},
	})
	conn := newUDPConnLocalhost(t)
	srcConnID := randConnID(6)
	destConnID := randConnID(8)
	p := getValidInitialPacket(t, conn.LocalAddr(), srcConnID, destConnID)
	argsChan := make(chan sentPacketCallArgs, 1)
	mockTracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(
		func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
			argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames}
		},
	)
	server.handlePacket(p)
	// the connection close echoes the connection IDs, swapped
	checkConnectionClose(t, conn, argsChan, destConnID, srcConnID, qerr.ConnectionRefused)
}
// TestServerReceiveQueue verifies that the server drops incoming packets with
// PacketDropDOSPrevention once its receive queue
// (protocol.MaxServerUnprocessedPackets) is full.
func TestServerReceiveQueue(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	acceptConn := make(chan struct{})
	defer close(acceptConn)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	newConnChan := make(chan struct{}, protocol.MaxServerUnprocessedPackets+2)
	server := newTestServer(t, &serverOpts{
		tracer: tracer,
		// a connection constructor that blocks until acceptConn is closed,
		// stalling the server's run loop so packets pile up in the queue
		newConn: func(
			_ context.Context,
			_ context.CancelCauseFunc,
			_ sendConn,
			_ connRunner,
			_ protocol.ConnectionID,
			_ *protocol.ConnectionID,
			_ protocol.ConnectionID,
			_ protocol.ConnectionID,
			_ protocol.ConnectionID,
			_ ConnectionIDGenerator,
			_ *statelessResetter,
			_ *Config,
			_ *tls.Config,
			_ *handshake.TokenGenerator,
			_ bool,
			_ time.Duration,
			_ *logging.ConnectionTracer,
			_ utils.Logger,
			_ protocol.Version,
		) *wrappedConn {
			newConnChan <- struct{}{}
			<-acceptConn
			return &wrappedConn{testHooks: &connTestHooks{handlePacket: func(receivedPacket) {}}}
		},
	})
	conn := newUDPConnLocalhost(t)
	// fill up the receive queue: one packet is being processed (blocked in
	// newConn), MaxServerUnprocessedPackets are queued
	for i := range protocol.MaxServerUnprocessedPackets + 1 {
		server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8)))
		// newConn blocks on the acceptConn channel, so this blocks the server's run loop
		if i == 0 {
			select {
			case <-newConnChan:
			case <-time.After(time.Second):
				t.Fatal("timeout")
			}
		}
	}
	// the next packet overflows the queue and is dropped
	done := make(chan struct{})
	mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeNotDetermined, gomock.Any(), logging.PacketDropDOSPrevention).Do(
		func(_ net.Addr, _ logging.PacketType, _ logging.ByteCount, _ logging.PacketDropReason) {
			close(done)
		},
	)
	server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8)))
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerAccept runs the Accept test with and without early (0-RTT) accepting.
func TestServerAccept(t *testing.T) {
	for _, acceptEarly := range []bool{false, true} {
		name := "without accept early"
		if acceptEarly {
			name = "with accept early"
		}
		t.Run(name, func(t *testing.T) { testServerAccept(t, acceptEarly) })
	}
}
// testServerAccept verifies that Accept
//  1. respects context cancellation, and
//  2. only returns a connection once it is ready: after the handshake
//     completes, or — with acceptEarly — as soon as the connection is
//     ready for 0-RTT.
func testServerAccept(t *testing.T, acceptEarly bool) {
	mockCtrl := gomock.NewController(t)
	ready := make(chan struct{})
	hooks := &connTestHooks{}
	// the readiness signal the server waits for depends on the accept mode
	if acceptEarly {
		hooks.earlyConnReady = func() <-chan struct{} { return ready }
	} else {
		hooks.handshakeComplete = func() <-chan struct{} { return ready }
	}
	recorder := newConnConstructorRecorder(hooks)
	tracer, _ := mocklogging.NewMockTracer(mockCtrl)
	server := newTestServer(t, &serverOpts{
		acceptEarly: acceptEarly,
		tracer:      tracer,
		newConn:     recorder.NewConn,
	})
	// Accept should respect the context
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, err := server.Accept(ctx)
	require.ErrorIs(t, err, context.Canceled)
	// establish a new connection, which then starts handshaking
	server.handlePacket(getValidInitialPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		randConnID(6),
		randConnID(8),
	))
	accepted := make(chan error, 1)
	go func() {
		_, err := server.Accept(context.Background())
		accepted <- err
	}()
	// the connection isn't ready yet, so Accept must block
	select {
	case <-accepted:
		t.Fatal("server accepted the connection too early")
	case <-time.After(scaleDuration(5 * time.Millisecond)):
	}
	// now complete the handshake
	close(ready)
	select {
	case err := <-accepted:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerAcceptHandshakeFailure verifies that a connection whose context
// is cancelled (i.e. whose handshake fails) is never handed out by Accept.
func TestServerAcceptHandshakeFailure(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	recorder := newConnConstructorRecorder(&connTestHooks{
		context: func() context.Context { return ctx },
		// the handshake never completes
		handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
	})
	server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
	// establish a new connection, which then starts handshaking
	server.handlePacket(getValidInitialPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		randConnID(6),
		randConnID(8),
	))
	accepted := make(chan error, 1)
	go func() {
		_, err := server.Accept(context.Background())
		accepted <- err
	}()
	// cancelling the connection's context simulates a handshake failure
	cancel()
	select {
	case <-accepted:
		t.Fatal("server should not have accepted the connection")
	case <-time.After(scaleDuration(5 * time.Millisecond)):
	}
}
// TestServerAcceptQueue verifies that
//  1. a connection that would overflow the accept queue
//     (protocol.MaxAcceptQueueSize) is rejected with CONNECTION_REFUSED, and
//  2. accepting a connection frees a slot for a new one.
//
// Fix: removed a dead `continue` that was the last statement of the setup
// loop body (flagged by staticcheck), and restructured that loop slightly.
func TestServerAcceptQueue(t *testing.T) {
	var conns []*connTestHooks
	rejectedCloseError := make(chan TransportErrorCode, 1)
	for i := range protocol.MaxAcceptQueueSize + 2 {
		conn := &connTestHooks{
			// the handshake completes immediately
			handshakeComplete: func() <-chan struct{} {
				c := make(chan struct{})
				close(c)
				return c
			},
		}
		// the connection overflowing the accept queue is expected to be closed
		if i == protocol.MaxAcceptQueueSize {
			conn.closeWithTransportError = func(code TransportErrorCode) { rejectedCloseError <- code }
		}
		conns = append(conns, conn)
	}
	recorder := newConnConstructorRecorder(conns...)
	server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
	// fill up the accept queue
	for range protocol.MaxAcceptQueueSize {
		b := make([]byte, 16)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		server.handlePacket(
			getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), connID),
		)
		select {
		case args := <-recorder.Args():
			require.Equal(t, connID, args.origDestConnID)
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	// wait for the connection to be enqueued
	time.Sleep(scaleDuration(10 * time.Millisecond))
	// this connection overflows the queue and is rejected
	server.handlePacket(
		getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8)),
	)
	select {
	case <-recorder.Args():
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	select {
	case code := <-rejectedCloseError:
		require.Equal(t, ConnectionRefused, code)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// accept one connection, freeing up one slot in the accept queue
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	_, err := server.Accept(ctx)
	require.NoError(t, err)
	// it's now possible to enqueue a new connection
	server.handlePacket(
		getValidInitialPacket(t,
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
			randConnID(6),
			protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
		),
	)
	select {
	case args := <-recorder.Args():
		require.Equal(t, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), args.origDestConnID)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServer0RTTReordering verifies the handling of 0-RTT packets that arrive
// before the corresponding Initial: up to protocol.Max0RTTQueueLen packets
// are queued and delivered to the connection (after the Initial, in arrival
// order); any further 0-RTT packet is dropped with PacketDropDOSPrevention.
func TestServer0RTTReordering(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
	packets := make(chan receivedPacket, protocol.Max0RTTQueueLen+1)
	done := make(chan struct{})
	recorder := newConnConstructorRecorder(&connTestHooks{
		handlePacket:   func(p receivedPacket) { packets <- p },
		earlyConnReady: func() <-chan struct{} { return make(chan struct{}) },
		run:            func() error { close(done); return nil },
	})
	server := newTestServer(t, &serverOpts{
		acceptEarly: true,
		tracer:      tracer,
		newConn:     recorder.NewConn,
	})
	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	// fill the 0-RTT queue before any Initial arrives
	var zeroRTTPackets []receivedPacket
	for range protocol.Max0RTTQueueLen {
		p := getLongHeaderPacket(t,
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
			&wire.ExtendedHeader{
				Header: wire.Header{
					Type:             protocol.PacketType0RTT,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: connID,
					Length:           100,
					Version:          protocol.Version1,
				},
				PacketNumberLen: protocol.PacketNumberLen4,
			},
			make([]byte, 100),
		)
		server.handlePacket(p)
		zeroRTTPackets = append(zeroRTTPackets, p)
	}
	// send one more packet, this one should be dropped
	p := getLongHeaderPacket(t,
		&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
		&wire.ExtendedHeader{
			Header: wire.Header{
				Type:             protocol.PacketType0RTT,
				SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
				DestConnectionID: connID,
				Length:           100,
				Version:          protocol.Version1,
			},
			PacketNumberLen: protocol.PacketNumberLen4,
		},
		make([]byte, 100),
	)
	mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
	server.handlePacket(p)
	// now receive the Initial
	initial := getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(5), connID)
	server.handlePacket(initial)
	// the connection receives the Initial first, then the queued 0-RTT packets in order
	for i := range protocol.Max0RTTQueueLen + 1 {
		select {
		case p := <-packets:
			if i == 0 {
				require.Equal(t, initial.data, p.data)
			} else {
				require.Equal(t, zeroRTTPackets[i-1], p)
			}
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServer0RTTQueueing verifies the per-server limit on the number of 0-RTT
// queues (protocol.Max0RTTQueues) and the time-based expiry of those queues
// (protocol.Max0RTTQueueingDuration). Once the maximum number of queues is in
// use, 0-RTT packets for unknown connection IDs are dropped. Expiry is lazy:
// it only happens when another packet is processed.
func TestServer0RTTQueueing(t *testing.T) {
	ctrl := gomock.NewController(t)
	tracer, mockTracer := mocklogging.NewMockTracer(ctrl)
	server := newTestServer(t, &serverOpts{
		acceptEarly: true,
		tracer:      tracer,
	})

	// all packets in this test share the same remote address and source
	// connection ID; type, destination connection ID and payload size vary
	newPacket := func(typ protocol.PacketType, destConnID protocol.ConnectionID, size protocol.ByteCount) receivedPacket {
		return getLongHeaderPacket(t,
			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
			&wire.ExtendedHeader{
				Header: wire.Header{
					Type:             typ,
					SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
					DestConnectionID: destConnID,
					Length:           size,
					Version:          protocol.Version1,
				},
				PacketNumberLen: protocol.PacketNumberLen4,
			},
			make([]byte, size),
		)
	}

	firstRcvTime := time.Now()
	otherRcvTime := firstRcvTime.Add(protocol.Max0RTTQueueingDuration / 2)
	// fill up all 0-RTT queues, using a fresh (random) connection ID for each;
	// the first queue is older than the rest, so it expires first
	sizes := make([]protocol.ByteCount, 0, protocol.Max0RTTQueues)
	for i := range protocol.Max0RTTQueues {
		connIDBytes := make([]byte, 16)
		rand.Read(connIDBytes)
		p := newPacket(protocol.PacketType0RTT, protocol.ParseConnectionID(connIDBytes), protocol.ByteCount(500+i))
		if i == 0 {
			p.rcvTime = firstRcvTime
		} else {
			p.rcvTime = otherRcvTime
		}
		sizes = append(sizes, p.Size())
		server.handlePacket(p)
	}

	// the maximum number of 0-RTT queues is in use, so further packets are dropped
	p := newPacket(protocol.PacketType0RTT, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), 123)
	dropped := make(chan struct{}, protocol.Max0RTTQueues)
	mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(
		func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
	)
	server.handlePacket(p)
	select {
	case <-dropped:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, ctrl.Satisfied())

	// Queue expiry isn't driven by a background goroutine; it runs whenever a
	// new packet is processed.
	// Step 1: a Handshake packet arriving after the first queue's deadline
	// triggers the cleanup of that (oldest) 0-RTT queue.
	mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any())
	mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, sizes[0], gomock.Any()).Do(
		func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
	)
	trigger := newPacket(protocol.PacketTypeHandshake, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), 123)
	trigger.rcvTime = firstRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
	server.handlePacket(trigger)
	select {
	case <-dropped:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, ctrl.Satisfied())

	// Step 2: another Handshake packet, past the deadline of the remaining
	// queues, triggers the cleanup of all of them.
	trigger = newPacket(protocol.PacketTypeHandshake, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), 124)
	trigger.rcvTime = otherRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
	mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any())
	for _, size := range sizes[1:] {
		mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, size, gomock.Any()).Do(
			func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
		)
	}
	server.handlePacket(trigger)
	for range protocol.Max0RTTQueues - 1 {
		select {
		case <-dropped:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
}