forked from quic-go/quic-go
1293 lines
39 KiB
Go
1293 lines
39 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"slices"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/handshake"
|
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/utils"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
"github.com/quic-go/quic-go/logging"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// testServer embeds a *baseServer so that tests can reach the server's
// internals (handlePacket, newConn, Accept, Close) directly.
type testServer struct {
	*baseServer
}
|
|
|
|
type serverOpts struct {
|
|
tracer *logging.Tracer
|
|
config *Config
|
|
tokenGeneratorKey TokenGeneratorKey
|
|
maxTokenAge time.Duration
|
|
useRetry bool
|
|
disableVersionNegotiation bool
|
|
acceptEarly bool
|
|
newConn func(
|
|
context.Context,
|
|
context.CancelCauseFunc,
|
|
sendConn,
|
|
connRunner,
|
|
protocol.ConnectionID, // original dest connection ID
|
|
*protocol.ConnectionID, // retry src connection ID
|
|
protocol.ConnectionID, // client dest connection ID
|
|
protocol.ConnectionID, // destination connection ID
|
|
protocol.ConnectionID, // source connection ID
|
|
ConnectionIDGenerator,
|
|
*statelessResetter,
|
|
*Config,
|
|
*tls.Config,
|
|
*handshake.TokenGenerator,
|
|
bool, /* client address validated by an address validation token */
|
|
time.Duration,
|
|
*logging.ConnectionTracer,
|
|
utils.Logger,
|
|
protocol.Version,
|
|
) *wrappedConn
|
|
}
|
|
|
|
func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
|
|
t.Helper()
|
|
c, err := wrapConn(newUDPConnLocalhost(t))
|
|
require.NoError(t, err)
|
|
verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry }
|
|
config := populateConfig(serverOpts.config)
|
|
tr := &Transport{Conn: newUDPConnLocalhost(t)}
|
|
tr.init(true)
|
|
s := newServer(
|
|
c,
|
|
(*packetHandlerMap)(tr),
|
|
&protocol.DefaultConnectionIDGenerator{},
|
|
&statelessResetter{},
|
|
func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil },
|
|
&tls.Config{},
|
|
config,
|
|
serverOpts.tracer,
|
|
func() {},
|
|
serverOpts.tokenGeneratorKey,
|
|
serverOpts.maxTokenAge,
|
|
verifySourceAddress,
|
|
serverOpts.disableVersionNegotiation,
|
|
serverOpts.acceptEarly,
|
|
)
|
|
s.newConn = serverOpts.newConn
|
|
t.Cleanup(func() { s.Close() })
|
|
return &testServer{s}
|
|
}
|
|
|
|
func getLongHeaderPacketEncrypted(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
|
|
t.Helper()
|
|
hdr := extHdr.Header
|
|
if hdr.Type != protocol.PacketTypeInitial {
|
|
t.Fatal("can only encrypt Initial packets")
|
|
}
|
|
p := getLongHeaderPacket(t, remoteAddr, extHdr, data)
|
|
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
|
|
n := len(p.data) - len(data) // length of the header
|
|
p.data = slices.Grow(p.data, 16)
|
|
_ = sealer.Seal(p.data[n:n], p.data[n:], extHdr.PacketNumber, p.data[:n])
|
|
p.data = p.data[:len(p.data)+16]
|
|
sealer.EncryptHeader(p.data[n:n+16], &p.data[0], p.data[n-int(extHdr.PacketNumberLen):n])
|
|
return p
|
|
}
|
|
|
|
func randConnID(l int) protocol.ConnectionID {
|
|
b := make([]byte, l)
|
|
rand.Read(b)
|
|
return protocol.ParseConnectionID(b)
|
|
}
|
|
|
|
func getValidInitialPacket(t *testing.T, raddr net.Addr, srcConnID, destConnID protocol.ConnectionID) receivedPacket {
|
|
t.Helper()
|
|
return getLongHeaderPacket(t,
|
|
raddr,
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: srcConnID,
|
|
DestConnectionID: destConnID,
|
|
Length: protocol.MinInitialPacketSize,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, protocol.MinInitialPacketSize),
|
|
)
|
|
}
|
|
|
|
type sentPacketCallArgs struct {
|
|
hdr *logging.Header
|
|
frames []logging.Frame
|
|
}
|
|
|
|
// checkConnectionClose checks
|
|
// 1. the arguments of the SentPacket tracer call, and
|
|
// 2. reads and parses the packet sent by the server
|
|
func checkConnectionClose(
|
|
t *testing.T,
|
|
conn *net.UDPConn,
|
|
argChan <-chan sentPacketCallArgs,
|
|
expectedSrcConnID protocol.ConnectionID,
|
|
expectedDestConnID protocol.ConnectionID,
|
|
expectedErrorCode qerr.TransportErrorCode,
|
|
) {
|
|
t.Helper()
|
|
select {
|
|
case args := <-argChan:
|
|
require.Equal(t, protocol.PacketTypeInitial, args.hdr.Type)
|
|
require.Equal(t, expectedSrcConnID, args.hdr.SrcConnectionID)
|
|
require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID)
|
|
require.Len(t, args.frames, 1)
|
|
require.IsType(t, &wire.ConnectionCloseFrame{}, args.frames[0])
|
|
ccf := args.frames[0].(*logging.ConnectionCloseFrame)
|
|
require.EqualValues(t, expectedErrorCode, ccf.ErrorCode)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
b := make([]byte, 1500)
|
|
n, _, err := conn.ReadFromUDP(b)
|
|
require.NoError(t, err)
|
|
parsedHdr, _, _, err := wire.ParsePacket(b[:n])
|
|
require.NoError(t, err)
|
|
require.Equal(t, protocol.PacketTypeInitial, parsedHdr.Type)
|
|
require.Equal(t, expectedSrcConnID, parsedHdr.SrcConnectionID)
|
|
require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID)
|
|
}
|
|
|
|
func checkRetry(t *testing.T,
|
|
conn *net.UDPConn,
|
|
argChan <-chan sentPacketCallArgs,
|
|
expectedDestConnID protocol.ConnectionID,
|
|
) {
|
|
t.Helper()
|
|
select {
|
|
case args := <-argChan:
|
|
require.Equal(t, protocol.PacketTypeRetry, args.hdr.Type)
|
|
require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID)
|
|
require.NotNil(t, args.hdr.Token)
|
|
require.Empty(t, args.frames)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
b := make([]byte, 1500)
|
|
n, _, err := conn.ReadFromUDP(b)
|
|
require.NoError(t, err)
|
|
parsedHdr, _, _, err := wire.ParsePacket(b[:n])
|
|
require.NoError(t, err)
|
|
require.Equal(t, protocol.PacketTypeRetry, parsedHdr.Type)
|
|
require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID)
|
|
require.NotNil(t, parsedHdr.Token)
|
|
}
|
|
|
|
func TestListen(t *testing.T) {
|
|
_, err := ListenAddr("localhost:0", nil, nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "quic: tls.Config not set")
|
|
|
|
_, err = Listen(nil, &tls.Config{}, &Config{Versions: []protocol.Version{0x1234}})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "invalid QUIC version: 0x1234")
|
|
}
|
|
|
|
func TestListenAddr(t *testing.T) {
|
|
_, err := ListenAddr("127.0.0.1", &tls.Config{}, &Config{})
|
|
require.Error(t, err)
|
|
require.IsType(t, &net.AddrError{}, err)
|
|
|
|
_, err = ListenAddr("1.1.1.1:1111", &tls.Config{}, &Config{})
|
|
require.Error(t, err)
|
|
require.IsType(t, &net.OpError{}, err)
|
|
|
|
ln, err := ListenAddr("127.0.0.1:0", &tls.Config{}, &Config{})
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
}
|
|
|
|
func TestServerPacketDropping(t *testing.T) {
|
|
t.Run("destination connection ID too short", func(t *testing.T) {
|
|
conn := newUDPConnLocalhost(t)
|
|
testServerDroppedPacket(t,
|
|
conn,
|
|
getValidInitialPacket(t, conn.LocalAddr(), randConnID(5), randConnID(7)),
|
|
logging.PacketTypeInitial,
|
|
logging.PacketDropUnexpectedPacket,
|
|
)
|
|
})
|
|
|
|
t.Run("Initial packet too small", func(t *testing.T) {
|
|
conn := newUDPConnLocalhost(t)
|
|
p := getLongHeaderPacket(t,
|
|
conn.LocalAddr(),
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: randConnID(8),
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: 2,
|
|
},
|
|
make([]byte, protocol.MinInitialPacketSize-100),
|
|
)
|
|
require.Greater(t, len(p.data), protocol.MinInitialPacketSize-100)
|
|
require.Less(t, len(p.data), protocol.MinInitialPacketSize)
|
|
testServerDroppedPacket(t,
|
|
conn,
|
|
p,
|
|
logging.PacketTypeInitial,
|
|
logging.PacketDropUnexpectedPacket,
|
|
)
|
|
})
|
|
|
|
// we should not send a Version Negotiation packet if the packet is smaller than 1200 bytes
|
|
t.Run("packet of unknown version, too small", func(t *testing.T) {
|
|
conn := newUDPConnLocalhost(t)
|
|
p := getLongHeaderPacket(t,
|
|
conn.LocalAddr(),
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: randConnID(8),
|
|
Version: 0x42,
|
|
},
|
|
PacketNumberLen: 2,
|
|
},
|
|
make([]byte, protocol.MinUnknownVersionPacketSize-100),
|
|
)
|
|
require.Greater(t, len(p.data), protocol.MinUnknownVersionPacketSize-100)
|
|
require.Less(t, len(p.data), protocol.MinUnknownVersionPacketSize)
|
|
testServerDroppedPacket(t,
|
|
conn,
|
|
p,
|
|
logging.PacketTypeNotDetermined,
|
|
logging.PacketDropUnexpectedPacket,
|
|
)
|
|
})
|
|
|
|
t.Run("not an Initial packet", func(t *testing.T) {
|
|
conn := newUDPConnLocalhost(t)
|
|
testServerDroppedPacket(t,
|
|
conn,
|
|
getLongHeaderPacket(t,
|
|
conn.LocalAddr(),
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: 2,
|
|
},
|
|
nil,
|
|
),
|
|
logging.PacketTypeHandshake,
|
|
logging.PacketDropUnexpectedPacket,
|
|
)
|
|
})
|
|
|
|
// as a server, we should never receive a Version Negotiation packet
|
|
t.Run("Version Negotiation packet", func(t *testing.T) {
|
|
conn := newUDPConnLocalhost(t)
|
|
data := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
|
|
protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
|
|
[]protocol.Version{1, 2, 3},
|
|
)
|
|
testServerDroppedPacket(t,
|
|
conn,
|
|
receivedPacket{
|
|
remoteAddr: conn.LocalAddr(),
|
|
data: data,
|
|
buffer: getPacketBuffer(),
|
|
},
|
|
logging.PacketTypeVersionNegotiation,
|
|
logging.PacketDropUnexpectedPacket,
|
|
)
|
|
})
|
|
}
|
|
|
|
func testServerDroppedPacket(t *testing.T,
|
|
conn *net.UDPConn,
|
|
p receivedPacket,
|
|
expectedPacketType logging.PacketType,
|
|
expectedDropReason logging.PacketDropReason,
|
|
) {
|
|
readChan := make(chan struct{})
|
|
go func() {
|
|
defer close(readChan)
|
|
conn.ReadFrom(make([]byte, 1000))
|
|
}()
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{tracer: tracer})
|
|
|
|
mockTracer.EXPECT().DroppedPacket(p.remoteAddr, expectedPacketType, p.Size(), expectedDropReason)
|
|
server.handlePacket(p)
|
|
|
|
select {
|
|
case <-readChan:
|
|
t.Fatal("didn't expect to receive a packet")
|
|
case <-time.After(scaleDuration(5 * time.Millisecond)):
|
|
}
|
|
}
|
|
|
|
func TestServerVersionNegotiation(t *testing.T) {
|
|
t.Run("enabled", func(t *testing.T) {
|
|
testServerVersionNegotiation(t, true)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
testServerVersionNegotiation(t, false)
|
|
})
|
|
}
|
|
|
|
func testServerVersionNegotiation(t *testing.T, enabled bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
conn := newUDPConnLocalhost(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
tracer: tracer,
|
|
disableVersionNegotiation: !enabled,
|
|
})
|
|
|
|
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
|
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
|
|
packet := getLongHeaderPacket(t, conn.LocalAddr(),
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: srcConnID,
|
|
DestConnectionID: destConnID,
|
|
Version: 0x42,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, protocol.MinUnknownVersionPacketSize),
|
|
)
|
|
|
|
switch enabled {
|
|
case true:
|
|
mockTracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any())
|
|
case false:
|
|
mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeNotDetermined, packet.Size(), logging.PacketDropUnexpectedVersion)
|
|
}
|
|
|
|
written := make(chan []byte, 1)
|
|
go func() {
|
|
b := make([]byte, 1500)
|
|
n, _, _ := conn.ReadFrom(b)
|
|
written <- b[:n]
|
|
}()
|
|
server.handlePacket(packet)
|
|
|
|
switch enabled {
|
|
case true:
|
|
select {
|
|
case b := <-written:
|
|
require.True(t, wire.IsVersionNegotiationPacket(b))
|
|
dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
|
|
require.NoError(t, err)
|
|
require.Equal(t, protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), dest)
|
|
require.Equal(t, protocol.ArbitraryLenConnectionID(destConnID.Bytes()), src)
|
|
require.NotContains(t, versions, protocol.Version(0x42))
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
case false:
|
|
select {
|
|
case <-written:
|
|
t.Fatal("expected no version negotiation packet")
|
|
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerRetry(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{tracer: tracer, useRetry: true})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
|
|
packet := getLongHeaderPacket(t, conn.LocalAddr(),
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, protocol.MinUnknownVersionPacketSize),
|
|
)
|
|
argsChan := make(chan sentPacketCallArgs, 1)
|
|
mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(
|
|
func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
|
|
argsChan <- sentPacketCallArgs{hdr: hdr}
|
|
},
|
|
)
|
|
server.handlePacket(packet)
|
|
checkRetry(t, conn, argsChan, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))
|
|
}
|
|
|
|
func TestServerTokenValidation(t *testing.T) {
|
|
var tokenGeneratorKey handshake.TokenProtectorKey
|
|
rand.Read(tokenGeneratorKey[:])
|
|
tg := handshake.NewTokenGenerator(tokenGeneratorKey)
|
|
|
|
t.Run("retry token with invalid address", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
token, err := tg.NewRetryToken(
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
|
protocol.ConnectionID{},
|
|
protocol.ConnectionID{},
|
|
)
|
|
require.NoError(t, err)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
useRetry: true,
|
|
tracer: tracer,
|
|
tokenGeneratorKey: tokenGeneratorKey,
|
|
})
|
|
|
|
testServerTokenValidation(t, server, mockTracer, newUDPConnLocalhost(t), token, false, true, false)
|
|
})
|
|
|
|
t.Run("expired retry token", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
conn := newUDPConnLocalhost(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
useRetry: true,
|
|
tracer: tracer,
|
|
config: &Config{HandshakeIdleTimeout: time.Millisecond / 2},
|
|
tokenGeneratorKey: tokenGeneratorKey,
|
|
})
|
|
|
|
token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{})
|
|
require.NoError(t, err)
|
|
// the maximum retry token age is equivalent to the handshake timeout
|
|
time.Sleep(time.Millisecond) // make sure the token is expired
|
|
testServerTokenValidation(t, server, mockTracer, conn, token, false, true, false)
|
|
})
|
|
|
|
// if the packet is corrupted, it will just be dropped (no INVALID_TOKEN nor Retry is sent)
|
|
t.Run("corrupted packet", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
useRetry: true,
|
|
tracer: tracer,
|
|
config: &Config{HandshakeIdleTimeout: time.Millisecond / 2},
|
|
tokenGeneratorKey: tokenGeneratorKey,
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{})
|
|
require.NoError(t, err)
|
|
time.Sleep(time.Millisecond) // make sure the token is expired
|
|
testServerTokenValidation(t, server, mockTracer, conn, token, true, false, true)
|
|
})
|
|
|
|
t.Run("invalid non-retry token", func(t *testing.T) {
|
|
var tokenGeneratorKey2 handshake.TokenProtectorKey
|
|
rand.Read(tokenGeneratorKey2[:])
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
tokenGeneratorKey: tokenGeneratorKey2, // use a different key
|
|
useRetry: true,
|
|
tracer: tracer,
|
|
maxTokenAge: time.Millisecond,
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
token, err := tg.NewToken(conn.LocalAddr(), 10*time.Millisecond)
|
|
require.NoError(t, err)
|
|
time.Sleep(3 * time.Millisecond) // make sure the token is expired
|
|
testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true)
|
|
})
|
|
|
|
t.Run("expired non-retry token", func(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
tokenGeneratorKey: tokenGeneratorKey,
|
|
useRetry: true,
|
|
tracer: tracer,
|
|
maxTokenAge: time.Millisecond,
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
token, err := tg.NewToken(conn.LocalAddr(), 100*time.Millisecond)
|
|
require.NoError(t, err)
|
|
time.Sleep(3 * time.Millisecond) // make sure the token is expired
|
|
testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true)
|
|
})
|
|
}
|
|
|
|
func testServerTokenValidation(
|
|
t *testing.T,
|
|
server *testServer,
|
|
mockTracer *mocklogging.MockTracer,
|
|
conn *net.UDPConn,
|
|
token []byte,
|
|
corruptedPacket bool,
|
|
expectInvalidTokenConnectionClose bool,
|
|
expectRetry bool,
|
|
) {
|
|
hdr := wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
|
Token: token,
|
|
Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16,
|
|
Version: protocol.Version1,
|
|
}
|
|
packet := getLongHeaderPacketEncrypted(t,
|
|
conn.LocalAddr(),
|
|
&wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4},
|
|
make([]byte, protocol.MinInitialPacketSize),
|
|
)
|
|
if corruptedPacket {
|
|
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
|
|
done := make(chan struct{})
|
|
mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(
|
|
func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
},
|
|
)
|
|
server.handlePacket(packet)
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
return
|
|
}
|
|
|
|
argsChan := make(chan sentPacketCallArgs, 1)
|
|
if expectInvalidTokenConnectionClose || expectRetry {
|
|
mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames}
|
|
},
|
|
)
|
|
}
|
|
|
|
server.handlePacket(packet)
|
|
if expectInvalidTokenConnectionClose {
|
|
checkConnectionClose(t, conn, argsChan, hdr.DestConnectionID, hdr.SrcConnectionID, qerr.InvalidToken)
|
|
}
|
|
if expectRetry {
|
|
checkRetry(t, conn, argsChan, hdr.SrcConnectionID)
|
|
}
|
|
}
|
|
|
|
type connConstructorArgs struct {
|
|
ctx context.Context
|
|
connRunner connRunner
|
|
config *Config
|
|
origDestConnID protocol.ConnectionID
|
|
retrySrcConnID *protocol.ConnectionID
|
|
clientDestConnID protocol.ConnectionID
|
|
destConnID protocol.ConnectionID
|
|
srcConnID protocol.ConnectionID
|
|
}
|
|
|
|
type connConstructorRecorder struct {
|
|
ch chan connConstructorArgs
|
|
|
|
hooks []*connTestHooks
|
|
}
|
|
|
|
func newConnConstructorRecorder(hooks ...*connTestHooks) *connConstructorRecorder {
|
|
return &connConstructorRecorder{
|
|
ch: make(chan connConstructorArgs, len(hooks)),
|
|
hooks: hooks,
|
|
}
|
|
}
|
|
|
|
// Args returns the channel on which NewConn publishes the arguments it was
// called with.
func (r *connConstructorRecorder) Args() <-chan connConstructorArgs {
	return r.ch
}
|
|
|
|
func (r *connConstructorRecorder) NewConn(
|
|
ctx context.Context,
|
|
_ context.CancelCauseFunc,
|
|
_ sendConn,
|
|
connRunner connRunner,
|
|
origDestConnID protocol.ConnectionID,
|
|
retrySrcConnID *protocol.ConnectionID,
|
|
clientDestConnID protocol.ConnectionID,
|
|
destConnID protocol.ConnectionID,
|
|
srcConnID protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ *statelessResetter,
|
|
config *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ time.Duration,
|
|
_ *logging.ConnectionTracer,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) *wrappedConn {
|
|
r.ch <- connConstructorArgs{
|
|
ctx: ctx,
|
|
connRunner: connRunner,
|
|
config: config,
|
|
origDestConnID: origDestConnID,
|
|
retrySrcConnID: retrySrcConnID,
|
|
clientDestConnID: clientDestConnID,
|
|
destConnID: destConnID,
|
|
srcConnID: srcConnID,
|
|
}
|
|
hooks := r.hooks[0]
|
|
r.hooks = r.hooks[1:]
|
|
return &wrappedConn{testHooks: hooks}
|
|
}
|
|
|
|
func TestServerCreateConnection(t *testing.T) {
|
|
t.Run("without retry", func(t *testing.T) {
|
|
testServerCreateConnection(t, false)
|
|
})
|
|
t.Run("with retry", func(t *testing.T) {
|
|
testServerCreateConnection(t, true)
|
|
})
|
|
}
|
|
|
|
func testServerCreateConnection(t *testing.T, useRetry bool) {
|
|
tokenGeneratorKey := TokenGeneratorKey{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
|
|
tg := handshake.NewTokenGenerator(tokenGeneratorKey)
|
|
|
|
server := newTestServer(t, &serverOpts{
|
|
useRetry: useRetry,
|
|
tokenGeneratorKey: tokenGeneratorKey,
|
|
})
|
|
|
|
done := make(chan struct{}, 3)
|
|
handledPackets := make(chan receivedPacket, 1)
|
|
recorder := newConnConstructorRecorder(&connTestHooks{
|
|
run: func() error { done <- struct{}{}; return nil },
|
|
context: func() context.Context { done <- struct{}{}; return context.Background() },
|
|
handshakeComplete: func() <-chan struct{} { done <- struct{}{}; return make(chan struct{}) },
|
|
handlePacket: func(p receivedPacket) { handledPackets <- p },
|
|
})
|
|
server.newConn = recorder.NewConn
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
var token []byte
|
|
if useRetry {
|
|
var err error
|
|
token, err = tg.NewRetryToken(
|
|
conn.LocalAddr(),
|
|
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
|
|
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
|
|
)
|
|
require.NoError(t, err)
|
|
}
|
|
hdr := wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
|
Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16,
|
|
Token: token,
|
|
Version: protocol.Version1,
|
|
}
|
|
packet := getLongHeaderPacketEncrypted(t,
|
|
conn.LocalAddr(),
|
|
&wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4},
|
|
make([]byte, protocol.MinInitialPacketSize),
|
|
)
|
|
|
|
server.handlePacket(packet)
|
|
|
|
select {
|
|
case p := <-handledPackets:
|
|
require.Equal(t, packet, p)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
var args connConstructorArgs
|
|
select {
|
|
case args = <-recorder.Args():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID)
|
|
assert.NotEqual(t, args.origDestConnID, args.srcConnID)
|
|
if useRetry {
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID)
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), args.origDestConnID)
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), *args.retrySrcConnID)
|
|
} else {
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), args.origDestConnID)
|
|
assert.Zero(t, args.retrySrcConnID)
|
|
}
|
|
|
|
for range 3 {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServerClose(t *testing.T) {
|
|
var hooks []*connTestHooks
|
|
const numConns = 3
|
|
done := make(chan struct{}, numConns)
|
|
for range numConns {
|
|
hooks = append(hooks, &connTestHooks{
|
|
closeWithTransportError: func(TransportErrorCode) { done <- struct{}{} },
|
|
})
|
|
}
|
|
recorder := newConnConstructorRecorder(hooks...)
|
|
server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
|
|
|
|
for range numConns {
|
|
b := make([]byte, 10)
|
|
rand.Read(b)
|
|
connID := protocol.ParseConnectionID(b)
|
|
server.handlePacket(getValidInitialPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
randConnID(6),
|
|
connID,
|
|
))
|
|
select {
|
|
case <-recorder.Args():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
server.Close()
|
|
// closing closes all handshaking connections with CONNECTION_REFUSED
|
|
for range numConns {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
// Accept returns ErrServerClosed after closing
|
|
for range 5 {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err := server.Accept(ctx)
|
|
require.ErrorIs(t, err, ErrServerClosed)
|
|
require.ErrorIs(t, err, net.ErrClosed)
|
|
}
|
|
}
|
|
|
|
func TestServerGetConfigForClientAccept(t *testing.T) {
|
|
recorder := newConnConstructorRecorder(&connTestHooks{})
|
|
server := newTestServer(t, &serverOpts{
|
|
config: &Config{
|
|
GetConfigForClient: func(*ClientInfo) (*Config, error) {
|
|
return &Config{MaxIncomingStreams: 1234}, nil
|
|
},
|
|
},
|
|
newConn: recorder.NewConn,
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
packet := getValidInitialPacket(t,
|
|
conn.LocalAddr(),
|
|
protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
|
)
|
|
|
|
server.handlePacket(packet)
|
|
|
|
var args connConstructorArgs
|
|
select {
|
|
case args = <-recorder.Args():
|
|
require.EqualValues(t, 1234, args.config.MaxIncomingStreams)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID)
|
|
assert.NotEqual(t, args.origDestConnID, args.srcConnID)
|
|
}
|
|
|
|
func TestServerGetConfigForClientReject(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
tracer: tracer,
|
|
config: &Config{
|
|
GetConfigForClient: func(*ClientInfo) (*Config, error) {
|
|
return nil, errors.New("rejected")
|
|
},
|
|
},
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
srcConnID := randConnID(6)
|
|
destConnID := randConnID(8)
|
|
p := getValidInitialPacket(t, conn.LocalAddr(), srcConnID, destConnID)
|
|
argsChan := make(chan sentPacketCallArgs, 1)
|
|
mockTracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(
|
|
func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames}
|
|
},
|
|
)
|
|
server.handlePacket(p)
|
|
|
|
checkConnectionClose(t, conn, argsChan, destConnID, srcConnID, qerr.ConnectionRefused)
|
|
}
|
|
|
|
func TestServerReceiveQueue(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
acceptConn := make(chan struct{})
|
|
defer close(acceptConn)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
newConnChan := make(chan struct{}, protocol.MaxServerUnprocessedPackets+2)
|
|
server := newTestServer(t, &serverOpts{
|
|
tracer: tracer,
|
|
newConn: func(
|
|
_ context.Context,
|
|
_ context.CancelCauseFunc,
|
|
_ sendConn,
|
|
_ connRunner,
|
|
_ protocol.ConnectionID,
|
|
_ *protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ *statelessResetter,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ time.Duration,
|
|
_ *logging.ConnectionTracer,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) *wrappedConn {
|
|
newConnChan <- struct{}{}
|
|
<-acceptConn
|
|
return &wrappedConn{testHooks: &connTestHooks{handlePacket: func(receivedPacket) {}}}
|
|
},
|
|
})
|
|
|
|
conn := newUDPConnLocalhost(t)
|
|
for i := range protocol.MaxServerUnprocessedPackets + 1 {
|
|
server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8)))
|
|
// newConn blocks on the acceptConn channel, so this blocks the server's run loop
|
|
if i == 0 {
|
|
select {
|
|
case <-newConnChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeNotDetermined, gomock.Any(), logging.PacketDropDOSPrevention).Do(
|
|
func(_ net.Addr, _ logging.PacketType, _ logging.ByteCount, _ logging.PacketDropReason) {
|
|
close(done)
|
|
},
|
|
)
|
|
server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8)))
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestServerAccept(t *testing.T) {
|
|
t.Run("without accept early", func(t *testing.T) {
|
|
testServerAccept(t, false)
|
|
})
|
|
t.Run("with accept early", func(t *testing.T) {
|
|
testServerAccept(t, true)
|
|
})
|
|
}
|
|
|
|
func testServerAccept(t *testing.T, acceptEarly bool) {
|
|
mockCtrl := gomock.NewController(t)
|
|
ready := make(chan struct{})
|
|
hooks := &connTestHooks{}
|
|
if acceptEarly {
|
|
hooks.earlyConnReady = func() <-chan struct{} { return ready }
|
|
} else {
|
|
hooks.handshakeComplete = func() <-chan struct{} { return ready }
|
|
}
|
|
recorder := newConnConstructorRecorder(hooks)
|
|
tracer, _ := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
acceptEarly: acceptEarly,
|
|
tracer: tracer,
|
|
newConn: recorder.NewConn,
|
|
})
|
|
|
|
// Accept should respect the context
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
_, err := server.Accept(ctx)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
|
|
// establish a new connection, which then starts handshaking
|
|
server.handlePacket(getValidInitialPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
randConnID(6),
|
|
randConnID(8),
|
|
))
|
|
|
|
accepted := make(chan error, 1)
|
|
go func() {
|
|
_, err := server.Accept(context.Background())
|
|
accepted <- err
|
|
}()
|
|
|
|
select {
|
|
case <-accepted:
|
|
t.Fatal("server accepted the connection too early")
|
|
case <-time.After(scaleDuration(5 * time.Millisecond)):
|
|
}
|
|
// now complete the handshake
|
|
close(ready)
|
|
|
|
select {
|
|
case err := <-accepted:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestServerAcceptHandshakeFailure(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
recorder := newConnConstructorRecorder(&connTestHooks{
|
|
context: func() context.Context { return ctx },
|
|
handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
|
|
})
|
|
server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
|
|
|
|
// establish a new connection, which then starts handshaking
|
|
server.handlePacket(getValidInitialPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
randConnID(6),
|
|
randConnID(8),
|
|
))
|
|
|
|
accepted := make(chan error, 1)
|
|
go func() {
|
|
_, err := server.Accept(context.Background())
|
|
accepted <- err
|
|
}()
|
|
|
|
cancel()
|
|
select {
|
|
case <-accepted:
|
|
t.Fatal("server should not have accepted the connection")
|
|
case <-time.After(scaleDuration(5 * time.Millisecond)):
|
|
}
|
|
}
|
|
|
|
func TestServerAcceptQueue(t *testing.T) {
|
|
var conns []*connTestHooks
|
|
rejectedCloseError := make(chan TransportErrorCode, 1)
|
|
for i := range protocol.MaxAcceptQueueSize + 2 {
|
|
conn := &connTestHooks{
|
|
handshakeComplete: func() <-chan struct{} {
|
|
c := make(chan struct{})
|
|
close(c)
|
|
return c
|
|
},
|
|
}
|
|
conns = append(conns, conn)
|
|
if i == protocol.MaxAcceptQueueSize {
|
|
conn.closeWithTransportError = func(code TransportErrorCode) { rejectedCloseError <- code }
|
|
continue
|
|
}
|
|
}
|
|
recorder := newConnConstructorRecorder(conns...)
|
|
server := newTestServer(t, &serverOpts{newConn: recorder.NewConn})
|
|
|
|
for range protocol.MaxAcceptQueueSize {
|
|
b := make([]byte, 16)
|
|
rand.Read(b)
|
|
connID := protocol.ParseConnectionID(b)
|
|
server.handlePacket(
|
|
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), connID),
|
|
)
|
|
select {
|
|
case args := <-recorder.Args():
|
|
require.Equal(t, connID, args.origDestConnID)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
// wait for the connection to be enqueued
|
|
time.Sleep(scaleDuration(10 * time.Millisecond))
|
|
|
|
server.handlePacket(
|
|
getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8)),
|
|
)
|
|
select {
|
|
case <-recorder.Args():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case code := <-rejectedCloseError:
|
|
require.Equal(t, ConnectionRefused, code)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// accept one connection, freeing up one slot in the accept queue
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
_, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// it's now possible to enqueue a new connection
|
|
server.handlePacket(
|
|
getValidInitialPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
randConnID(6),
|
|
protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
|
|
),
|
|
)
|
|
select {
|
|
case args := <-recorder.Args():
|
|
require.Equal(t, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), args.origDestConnID)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestServer0RTTReordering(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
packets := make(chan receivedPacket, protocol.Max0RTTQueueLen+1)
|
|
done := make(chan struct{})
|
|
recorder := newConnConstructorRecorder(&connTestHooks{
|
|
handlePacket: func(p receivedPacket) { packets <- p },
|
|
earlyConnReady: func() <-chan struct{} { return make(chan struct{}) },
|
|
run: func() error { close(done); return nil },
|
|
})
|
|
server := newTestServer(t, &serverOpts{
|
|
acceptEarly: true,
|
|
tracer: tracer,
|
|
newConn: recorder.NewConn,
|
|
})
|
|
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
|
|
var zeroRTTPackets []receivedPacket
|
|
|
|
for range protocol.Max0RTTQueueLen {
|
|
p := getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketType0RTT,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Length: 100,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, 100),
|
|
)
|
|
server.handlePacket(p)
|
|
zeroRTTPackets = append(zeroRTTPackets, p)
|
|
}
|
|
|
|
// send one more packet, this one should be dropped
|
|
p := getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketType0RTT,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Length: 100,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, 100),
|
|
)
|
|
mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
|
server.handlePacket(p)
|
|
|
|
// now receive the Initial
|
|
initial := getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(5), connID)
|
|
server.handlePacket(initial)
|
|
|
|
for i := range protocol.Max0RTTQueueLen + 1 {
|
|
select {
|
|
case p := <-packets:
|
|
if i == 0 {
|
|
require.Equal(t, initial.data, p.data)
|
|
} else {
|
|
require.Equal(t, zeroRTTPackets[i-1], p)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestServer0RTTQueueing(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl)
|
|
server := newTestServer(t, &serverOpts{
|
|
acceptEarly: true,
|
|
tracer: tracer,
|
|
})
|
|
|
|
firstRcvTime := time.Now()
|
|
otherRcvTime := firstRcvTime.Add(protocol.Max0RTTQueueingDuration / 2)
|
|
var sizes []protocol.ByteCount
|
|
for i := range protocol.Max0RTTQueues {
|
|
b := make([]byte, 16)
|
|
rand.Read(b)
|
|
connID := protocol.ParseConnectionID(b)
|
|
size := protocol.ByteCount(500 + i)
|
|
p := getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketType0RTT,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Length: size,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, size),
|
|
)
|
|
if i == 0 {
|
|
p.rcvTime = firstRcvTime
|
|
} else {
|
|
p.rcvTime = otherRcvTime
|
|
}
|
|
sizes = append(sizes, p.Size())
|
|
server.handlePacket(p)
|
|
}
|
|
|
|
// maximum number of 0-RTT queues is reached, further packets are dropped
|
|
p := getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketType0RTT,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
|
Length: 123,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, 123),
|
|
)
|
|
dropped := make(chan struct{}, protocol.Max0RTTQueues)
|
|
mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(
|
|
func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
|
|
)
|
|
server.handlePacket(p)
|
|
select {
|
|
case <-dropped:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// There's no cleanup Go routine.
|
|
// Cleanup is triggered when new packets are received.
|
|
// 1. Receive one handshake packet, which triggers the cleanup of the first 0-RTT queue
|
|
mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any())
|
|
mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, sizes[0], gomock.Any()).Do(
|
|
func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
|
|
)
|
|
triggerPacket := getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
|
|
Length: 123,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, 123),
|
|
)
|
|
triggerPacket.rcvTime = firstRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
|
|
server.handlePacket(triggerPacket)
|
|
select {
|
|
case <-dropped:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
require.True(t, mockCtrl.Satisfied())
|
|
|
|
// 2. Receive another handshake packet, which triggers the cleanup of the other 0-RTT queues
|
|
triggerPacket = getLongHeaderPacket(t,
|
|
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42},
|
|
&wire.ExtendedHeader{
|
|
Header: wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
|
|
Length: 124,
|
|
Version: protocol.Version1,
|
|
},
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
},
|
|
make([]byte, 124),
|
|
)
|
|
triggerPacket.rcvTime = otherRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond)
|
|
mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any())
|
|
for i := range sizes[1:] {
|
|
mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, sizes[i+1], gomock.Any()).Do(
|
|
func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} },
|
|
)
|
|
}
|
|
server.handlePacket(triggerPacket)
|
|
|
|
for range protocol.Max0RTTQueues - 1 {
|
|
select {
|
|
case <-dropped:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
}
|