diff --git a/go.mod b/go.mod index d57589f6..8c1b7063 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 golang.org/x/sys v0.23.0 - golang.org/x/time v0.5.0 golang.org/x/tools v0.22.0 ) diff --git a/go.sum b/go.sum index e05c6fd8..ab8b2299 100644 --- a/go.sum +++ b/go.sum @@ -191,8 +191,6 @@ golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 836f14b2..1b6e5fce 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -44,8 +44,6 @@ golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= diff --git a/server_test.go b/server_test.go index 5d3535d9..03e2a44d 100644 --- a/server_test.go +++ b/server_test.go @@ -6,1521 +6,1315 @@ import ( "crypto/tls" "errors" "net" - "sync" - "sync/atomic" + "slices" + "testing" "time" - "golang.org/x/time/rate" - "github.com/quic-go/quic-go/internal/handshake" mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/testdata" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -var _ = Describe("Server", func() { - var ( - conn *MockPacketConn - tlsConf *tls.Config +type testServer struct{ *baseServer } + +type serverOpts struct { + tracer *logging.Tracer + config *Config + tokenGeneratorKey TokenGeneratorKey + maxTokenAge time.Duration + useRetry bool + disableVersionNegotiation bool + acceptEarly bool + newConn func( + context.Context, + context.CancelCauseFunc, + sendConn, + connRunner, + protocol.ConnectionID, // original dest connection ID + *protocol.ConnectionID, // retry src connection ID + protocol.ConnectionID, // client dest connection ID + protocol.ConnectionID, // destination connection ID + protocol.ConnectionID, // source connection ID + ConnectionIDGenerator, + *statelessResetter, + *Config, + *tls.Config, + *handshake.TokenGenerator, + bool, /* client address validated by an address validation token */ + *logging.ConnectionTracer, + utils.Logger, + protocol.Version, + ) quicConn +} + +func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer { + t.Helper() + c, err := wrapConn(newUDPConnLocalhost(t)) + require.NoError(t, err) + verifySourceAddress := func(net.Addr) bool { return serverOpts.useRetry } + config := populateConfig(serverOpts.config) + s := newServer( + c, + newPacketHandlerMap(nil, utils.DefaultLogger), + &protocol.DefaultConnectionIDGenerator{}, + &statelessResetter{}, + func(ctx context.Context) context.Context { return ctx }, + &tls.Config{}, + config, + serverOpts.tracer, + func() {}, + serverOpts.tokenGeneratorKey, + serverOpts.maxTokenAge, + verifySourceAddress, + serverOpts.disableVersionNegotiation, + serverOpts.acceptEarly, + ) + s.newConn = serverOpts.newConn + t.Cleanup(func() { s.Close() }) + return &testServer{s} +} + +func getLongHeaderPacketEncrypted(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket { + t.Helper() + hdr := extHdr.Header + if hdr.Type != protocol.PacketTypeInitial { + t.Fatal("can only encrypt Initial packets") + } + p := getLongHeaderPacket(t, remoteAddr, extHdr, data) + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) + n := len(p.data) - len(data) // length of the header + p.data = slices.Grow(p.data, 16) + _ = sealer.Seal(p.data[n:n], p.data[n:], extHdr.PacketNumber, p.data[:n]) + p.data = p.data[:len(p.data)+16] + sealer.EncryptHeader(p.data[n:n+16], &p.data[0], p.data[n-int(extHdr.PacketNumberLen):n]) + return p +} + +func randConnID(l int) protocol.ConnectionID { + b := make([]byte, l) + rand.Read(b) + return protocol.ParseConnectionID(b) +} + +func getValidInitialPacket(t *testing.T, raddr net.Addr, srcConnID, destConnID protocol.ConnectionID) receivedPacket { + t.Helper() + return getLongHeaderPacket(t, + raddr, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Length: protocol.MinInitialPacketSize, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, protocol.MinInitialPacketSize), + ) +} + +type sentPacketCallArgs struct { + hdr *logging.Header + frames []logging.Frame +} + +// checkConnectionClose checks +// 1. the arguments of the SentPacket tracer call, and +// 2. reads and parses the packet sent by the server +func checkConnectionClose( + t *testing.T, + conn *net.UDPConn, + argChan <-chan sentPacketCallArgs, + expectedSrcConnID protocol.ConnectionID, + expectedDestConnID protocol.ConnectionID, + expectedErrorCode qerr.TransportErrorCode, +) { + t.Helper() + select { + case args := <-argChan: + require.Equal(t, protocol.PacketTypeInitial, args.hdr.Type) + require.Equal(t, expectedSrcConnID, args.hdr.SrcConnectionID) + require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID) + require.Len(t, args.frames, 1) + require.IsType(t, &wire.ConnectionCloseFrame{}, args.frames[0]) + ccf := args.frames[0].(*logging.ConnectionCloseFrame) + require.EqualValues(t, expectedErrorCode, ccf.ErrorCode) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + conn.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1500) + n, _, err := conn.ReadFromUDP(b) + require.NoError(t, err) + parsedHdr, _, _, err := wire.ParsePacket(b[:n]) + require.NoError(t, err) + require.Equal(t, protocol.PacketTypeInitial, parsedHdr.Type) + require.Equal(t, expectedSrcConnID, parsedHdr.SrcConnectionID) + require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID) +} + +func checkRetry(t *testing.T, + conn *net.UDPConn, + argChan <-chan sentPacketCallArgs, + expectedDestConnID protocol.ConnectionID, +) { + t.Helper() + select { + case args := <-argChan: + require.Equal(t, protocol.PacketTypeRetry, args.hdr.Type) + require.Equal(t, expectedDestConnID, args.hdr.DestConnectionID) + require.NotNil(t, args.hdr.Token) + require.Empty(t, args.frames) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + conn.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1500) + n, _, err := conn.ReadFromUDP(b) + require.NoError(t, err) + parsedHdr, _, _, err := wire.ParsePacket(b[:n]) + require.NoError(t, err) + require.Equal(t, protocol.PacketTypeRetry, parsedHdr.Type) + require.Equal(t, expectedDestConnID, parsedHdr.DestConnectionID) + require.NotNil(t, parsedHdr.Token) +} + +func TestListen(t *testing.T) { + _, err := ListenAddr("localhost:0", nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "quic: tls.Config not set") + + _, err = Listen(nil, &tls.Config{}, &Config{Versions: []protocol.Version{0x1234}}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid QUIC version: 0x1234") +} + +func TestListenAddr(t *testing.T) { + _, err := ListenAddr("127.0.0.1", &tls.Config{}, &Config{}) + require.Error(t, err) + require.IsType(t, &net.AddrError{}, err) + + _, err = ListenAddr("1.1.1.1:1111", &tls.Config{}, &Config{}) + require.Error(t, err) + require.IsType(t, &net.OpError{}, err) + + ln, err := ListenAddr("127.0.0.1:0", &tls.Config{}, &Config{}) + require.NoError(t, err) + defer ln.Close() +} + +func TestServerPacketDropping(t *testing.T) { + t.Run("destination connection ID too short", func(t *testing.T) { + conn := newUDPConnLocalhost(t) + testServerDroppedPacket(t, + conn, + getValidInitialPacket(t, conn.LocalAddr(), randConnID(5), randConnID(7)), + logging.PacketTypeInitial, + logging.PacketDropUnexpectedPacket, + ) + }) + + t.Run("Initial packet too small", func(t *testing.T) { + conn := newUDPConnLocalhost(t) + p := getLongHeaderPacket(t, + conn.LocalAddr(), + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: randConnID(8), + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }, + make([]byte, protocol.MinInitialPacketSize-100), + ) + require.Greater(t, len(p.data), protocol.MinInitialPacketSize-100) + require.Less(t, len(p.data), protocol.MinInitialPacketSize) + testServerDroppedPacket(t, + conn, + p, + logging.PacketTypeInitial, + logging.PacketDropUnexpectedPacket, + ) + }) + + // we should not send a Version Negotiation packet if the packet is smaller than 1200 bytes + t.Run("packet of unknown version, too small", func(t *testing.T) { + conn := newUDPConnLocalhost(t) + p := getLongHeaderPacket(t, + conn.LocalAddr(), + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: randConnID(8), + Version: 0x42, + }, + PacketNumberLen: 2, + }, + make([]byte, protocol.MinUnknownVersionPacketSize-100), + ) + require.Greater(t, len(p.data), protocol.MinUnknownVersionPacketSize-100) + require.Less(t, len(p.data), protocol.MinUnknownVersionPacketSize) + testServerDroppedPacket(t, + conn, + p, + logging.PacketTypeNotDetermined, + logging.PacketDropUnexpectedPacket, + ) + }) + + t.Run("not an Initial packet", func(t *testing.T) { + conn := newUDPConnLocalhost(t) + testServerDroppedPacket(t, + conn, + getLongHeaderPacket(t, + conn.LocalAddr(), + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + Version: protocol.Version1, + }, + PacketNumberLen: 2, + }, + nil, + ), + logging.PacketTypeHandshake, + logging.PacketDropUnexpectedPacket, + ) + }) + + // as a server, we should never receive a Version Negotiation packet + t.Run("Version Negotiation packet", func(t *testing.T) { + conn := newUDPConnLocalhost(t) + data := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, + protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, + []protocol.Version{1, 2, 3}, + ) + testServerDroppedPacket(t, + conn, + receivedPacket{ + remoteAddr: conn.LocalAddr(), + data: data, + buffer: getPacketBuffer(), + }, + logging.PacketTypeVersionNegotiation, + logging.PacketDropUnexpectedPacket, + ) + }) +} + +func testServerDroppedPacket(t *testing.T, + conn *net.UDPConn, + p receivedPacket, + expectedPacketType logging.PacketType, + expectedDropReason logging.PacketDropReason, +) { + readChan := make(chan struct{}) + go func() { + defer close(readChan) + conn.ReadFrom(make([]byte, 1000)) + }() + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{tracer: tracer}) + + mockTracer.EXPECT().DroppedPacket(p.remoteAddr, expectedPacketType, p.Size(), expectedDropReason) + server.handlePacket(p) + + select { + case <-readChan: + t.Fatal("didn't expect to receive a packet") + case <-time.After(scaleDuration(5 * time.Millisecond)): + } +} + +func TestServerVersionNegotiation(t *testing.T) { + t.Run("enabled", func(t *testing.T) { + testServerVersionNegotiation(t, true) + }) + t.Run("disabled", func(t *testing.T) { + testServerVersionNegotiation(t, false) + }) +} + +func testServerVersionNegotiation(t *testing.T, enabled bool) { + mockCtrl := gomock.NewController(t) + conn := newUDPConnLocalhost(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + tracer: tracer, + disableVersionNegotiation: !enabled, + }) + + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) + packet := getLongHeaderPacket(t, conn.LocalAddr(), + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeHandshake, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, protocol.MinUnknownVersionPacketSize), ) - getPacket := func(hdr *wire.Header, p []byte) receivedPacket { - buf := getPacketBuffer() - hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 - var err error - buf.Data, err = (&wire.ExtendedHeader{ - Header: *hdr, - PacketNumber: 0x42, - PacketNumberLen: protocol.PacketNumberLen4, - }).Append(buf.Data, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - n := len(buf.Data) - buf.Data = append(buf.Data, p...) - data := buf.Data - sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) - _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) - data = data[:len(data)+16] - sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) - return receivedPacket{ - rcvTime: time.Now(), - remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, - data: data, - buffer: buf, + switch enabled { + case true: + mockTracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()) + case false: + mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeNotDetermined, packet.Size(), logging.PacketDropUnexpectedVersion) + } + + written := make(chan []byte, 1) + go func() { + b := make([]byte, 1500) + n, _, _ := conn.ReadFrom(b) + written <- b[:n] + }() + server.handlePacket(packet) + + switch enabled { + case true: + select { + case b := <-written: + require.True(t, wire.IsVersionNegotiationPacket(b)) + dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) + require.NoError(t, err) + require.Equal(t, protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), dest) + require.Equal(t, protocol.ArbitraryLenConnectionID(destConnID.Bytes()), src) + require.NotContains(t, versions, protocol.Version(0x42)) + case <-time.After(time.Second): + t.Fatal("timeout") + } + case false: + select { + case <-written: + t.Fatal("expected no version negotiation packet") + case <-time.After(scaleDuration(10 * time.Millisecond)): } } +} - getInitial := func(destConnID protocol.ConnectionID) receivedPacket { - senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: destConnID, - Version: protocol.Version1, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.buffer = getPacketBuffer() - p.remoteAddr = senderAddr - return p - } +func TestServerRetry(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{tracer: tracer, useRetry: true}) - getInitialWithRandomDestConnID := func() receivedPacket { - b := make([]byte, 10) - _, err := rand.Read(b) - Expect(err).ToNot(HaveOccurred()) + conn := newUDPConnLocalhost(t) - return getInitial(protocol.ParseConnectionID(b)) - } - - parseHeader := func(data []byte) *wire.Header { - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - return hdr - } - - checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) { - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackLongHeader(opener, replyHdr, b) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - _, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - } - - BeforeEach(func() { - conn = NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - wait := make(chan struct{}) - conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { - <-wait - return 0, nil, errors.New("done") - }).MaxTimes(1) - conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error { - close(wait) - conn.EXPECT().SetReadDeadline(time.Time{}) - return nil - }).MaxTimes(1) - tlsConf = testdata.GetTLSConfig() - tlsConf.NextProtos = []string{"proto1"} - }) - - It("errors when no tls.Config is given", func() { - _, err := ListenAddr("localhost:0", nil, nil) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) - }) - - It("errors when the Config contains an invalid version", func() { - version := protocol.Version(0x1234) - _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}}) - Expect(err).To(MatchError("invalid QUIC version: 0x1234")) - }) - - It("fills in default values if options are not set in the Config", func() { - ln, err := Listen(conn, tlsConf, &Config{}) - Expect(err).ToNot(HaveOccurred()) - server := ln.baseServer - Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) - Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) - Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - Expect(server.config.KeepAlivePeriod).To(BeZero()) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("setups with the right values", func() { - supportedVersions := []protocol.Version{protocol.Version1} - config := Config{ - Versions: supportedVersions, - HandshakeIdleTimeout: 1337 * time.Hour, - MaxIdleTimeout: 42 * time.Minute, - KeepAlivePeriod: 5 * time.Second, - } - ln, err := Listen(conn, tlsConf, &config) - Expect(err).ToNot(HaveOccurred()) - server := ln.baseServer - Expect(server.connHandler).ToNot(BeNil()) - Expect(server.config.Versions).To(Equal(supportedVersions)) - Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) - Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) - Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("listens on a given address", func() { - addr := "127.0.0.1:13579" - ln, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).ToNot(HaveOccurred()) - Expect(ln.Addr().String()).To(Equal(addr)) - // stop the listener - Expect(ln.Close()).To(Succeed()) - }) - - It("errors if given an invalid address", func() { - addr := "127.0.0.1" - _, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) - }) - - It("errors if given an invalid address", func() { - addr := "1.1.1.1:1111" - _, err := ListenAddr(addr, tlsConf, &Config{}) - Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) - }) - - Context("server accepting connections that completed the handshake", func() { - var ( - tr *Transport - serv *baseServer - phm *MockPacketHandlerManager - tracer *mocklogging.MockTracer - ) - - BeforeEach(func() { - var t *logging.Tracer - t, tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: t} - ln, err := tr.Listen(tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - serv = ln.baseServer - phm = NewMockPacketHandlerManager(mockCtrl) - serv.connHandler = phm - }) - - AfterEach(func() { - tracer.EXPECT().Close() - tr.Close() - }) - - Context("handling packets", func() { - It("drops Initial packets with a too short connection ID", func() { - p := getPacket(&wire.Header{ - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), - Version: serv.config.Versions[0], - }, nil) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("drops too small Initial", func() { - p := getPacket(&wire.Header{ - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize-100)) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("drops non-Initial packets", func() { - p := getPacket(&wire.Header{ - Type: protocol.PacketTypeHandshake, - Version: serv.config.Versions[0], - }, []byte("invalid")) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) - serv.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - - It("passes packets to existing connections", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - p := getPacket(&wire.Header{ - Type: protocol.PacketTypeInitial, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - conn := NewMockPacketHandler(mockCtrl) - phm.EXPECT().Get(connID).Return(conn, true) - handled := make(chan struct{}) - conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) - serv.handlePacket(p) - Eventually(handled).Should(BeClosed()) - }) - - It("creates a connection when the token is accepted", func() { - serv.verifySourceAddress = func(net.Addr) bool { return true } - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - retryToken, err := serv.tokenGenerator.NewRetryToken( - raddr, - protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), - protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), - ) - Expect(err).ToNot(HaveOccurred()) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: connID, - Version: protocol.Version1, - Token: retryToken, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - p.remoteAddr = raddr - run := make(chan struct{}) - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var newConnID protocol.ConnectionID - conn := NewMockQUICConn(mockCtrl) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) - Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) - Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(destConnID).To(Equal(hdr.SrcConnectionID)) - // make sure we're using a server-generated connection ID - Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) - Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - newConnID = srcConnID - conn.EXPECT().handlePacket(p) - conn.EXPECT().run().Do(func() error { close(run); return nil }) - conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - return conn - } - phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool { - Expect(cid).To(Equal(newConnID)) - return true - }) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.handlePacket(p) - // the Handshake packet is written by the connection. - // Make sure there are no Write calls on the packet conn. - time.Sleep(50 * time.Millisecond) - close(done) - }() - // make sure we're using a server-generated connection ID - Eventually(run).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - // shutdown - conn.EXPECT().closeWithTransportError(gomock.Any()) - }) - - It("sends a Version Negotiation Packet for unsupported versions", func() { - srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) - packet := getPacket(&wire.Header{ - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) { - Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) - Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) - dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) - Expect(err).ToNot(HaveOccurred()) - Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) - Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) - Expect(versions).ToNot(ContainElement(protocol.Version(0x42))) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send a Version Negotiation packets if sending them is disabled", func() { - serv.disableVersionNegotiation = true - srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) - packet := getPacket(&wire.Header{ - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, packet.Size(), logging.PacketDropUnexpectedVersion) - serv.handlePacket(packet) - Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) - }) - - It("ignores Version Negotiation packets", func() { - data := wire.ComposeVersionNegotiation( - protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, - protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, - []protocol.Version{1, 2, 3}, - ) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(done) - }) - serv.handlePacket(receivedPacket{ - remoteAddr: raddr, - data: data, - buffer: getPacketBuffer(), - }) - Eventually(done).Should(BeClosed()) - // make sure no other packet is sent - time.Sleep(scaleDuration(20 * time.Millisecond)) - }) - - It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { - srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) - p := getPacket(&wire.Header{ - Type: protocol.PacketTypeHandshake, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinUnknownVersionPacketSize-50)) - Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - p.remoteAddr = raddr - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(done) - }) - serv.handlePacket(p) - Eventually(done).Should(BeClosed()) - // make sure no other packet is sent - time.Sleep(scaleDuration(20 * time.Millisecond)) - }) - - It("replies with a Retry packet, if a token is required", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - var called bool - serv.verifySourceAddress = func(addr net.Addr) bool { - Expect(addr).To(Equal(raddr)) - called = true - return true - } - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: connID, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) - return len(b), nil - }) - phm.EXPECT().Get(connID) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - Expect(called).To(BeTrue()) - }) - - It("creates a connection, if no token is required", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: connID, - Version: protocol.Version1, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - run := make(chan struct{}) - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var newConnID protocol.ConnectionID - conn := NewMockQUICConn(mockCtrl) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(retrySrcConnID).To(BeNil()) - Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) - Expect(destConnID).To(Equal(hdr.SrcConnectionID)) - // make sure we're using a server-generated connection ID - Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) - Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - newConnID = srcConnID - conn.EXPECT().handlePacket(p) - conn.EXPECT().run().Do(func() error { close(run); return nil }) - conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - return conn - } - gomock.InOrder( - phm.EXPECT().Get(connID), - phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool { - Expect(c).To(Equal(newConnID)) - return true - }), - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.handlePacket(p) - // the Handshake packet is written by the connection - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - close(done) - }() - // make sure we're using a server-generated connection ID - Eventually(run).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - // shutdown - conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) - }) - - It("drops packets if the receive queue is full", func() { - serv.verifySourceAddress = func(net.Addr) bool { return false } - - phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() - - acceptConn := make(chan struct{}) - var counter atomic.Uint32 - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - <-acceptConn - counter.Add(1) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) - conn.EXPECT().run().MaxTimes(1) - conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) - // shutdown - conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) - return conn - } - - p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - serv.handlePacket(p) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) - var wg sync.WaitGroup - for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) - }() - } - wg.Wait() - - close(acceptConn) - Eventually( - func() uint32 { return counter.Load() }, - scaleDuration(100*time.Millisecond), - ).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) - Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) - }) - - It("only creates a single connection for a duplicate Initial", func() { - done := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().closeWithTransportError(qerr.ConnectionRefused).Do(func(qerr.TransportErrorCode) { - close(done) - }) - return conn - } - - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) - p := getInitial(connID) - phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision - Expect(serv.handlePacketImpl(p)).To(BeTrue()) - Eventually(done).Should(BeClosed()) - }) - - It("limits the number of unvalidated handshakes", func() { - const limit = 3 - limiter := rate.NewLimiter(0, limit) - serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() } - - phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes() - - connChan := make(chan *MockQUICConn, 1) - var wg sync.WaitGroup - wg.Add(limit) - done := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn := <-connChan - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run() - conn.EXPECT().Context().Return(context.Background()) - conn.EXPECT().HandshakeComplete().DoAndReturn(func() <-chan struct{} { wg.Done(); return done }) - return conn - } - - // Initiate the maximum number of allowed connection attempts. - for i := 0; i < limit; i++ { - conn := NewMockQUICConn(mockCtrl) - connChan <- conn - serv.handlePacket(getInitialWithRandomDestConnID()) - } - - // Now initiate another connection attempt. - p := getInitialWithRandomDestConnID() - tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - defer GinkgoRecover() - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - }) - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer GinkgoRecover() - defer close(done) - hdr, _, _, err := wire.ParsePacket(b) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) - return len(b), nil - }) - serv.handlePacket(p) - Eventually(done).Should(BeClosed()) - - for i := 0; i < limit; i++ { - _, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - } - wg.Wait() - }) - }) - - Context("token validation", func() { - It("decodes the token from the token field", func() { - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - c := NewMockQUICConn(mockCtrl) - c.EXPECT().handlePacket(gomock.Any()) - c.EXPECT().run() - c.EXPECT().HandshakeComplete() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - c.EXPECT().Context().Return(ctx) - return c - } - raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) - Expect(err).ToNot(HaveOccurred()) - packet := getPacket(&wire.Header{ - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - - done := make(chan struct{}) - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, _ packetHandler) bool { - close(done) - return true - }) - phm.EXPECT().Remove(gomock.Any()).AnyTimes() - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { - serv.verifySourceAddress = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - Token: token, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(frames).To(HaveLen(1)) - Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := frames[0].(*logging.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - checkConnectionCloseError(b, hdr, qerr.InvalidToken) - return len(b), nil - }) - phm.EXPECT().Get(gomock.Any()) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { - serv.verifySourceAddress = func(net.Addr) bool { return true } - serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout - Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond)) - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(2 * time.Millisecond) // make sure the token is expired - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - Token: token, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(frames).To(HaveLen(1)) - Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := frames[0].(*logging.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - checkConnectionCloseError(b, hdr, qerr.InvalidToken) - return len(b), nil - }) - phm.EXPECT().Get(gomock.Any()) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { - serv.verifySourceAddress = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - Token: token, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - return len(b), nil - }) - phm.EXPECT().Get(gomock.Any()) - serv.handlePacket(packet) - // make sure there are no Write calls on the packet conn - Eventually(done).Should(BeClosed()) - }) - - It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { - serv.verifySourceAddress = func(net.Addr) bool { return true } - serv.maxTokenAge = time.Millisecond - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - token, err := serv.tokenGenerator.NewToken(raddr) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(2 * time.Millisecond) // make sure the token is expired - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - Token: token, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - return len(b), nil - }) - phm.EXPECT().Get(gomock.Any()) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - Token: token, - Version: protocol.Version1, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) - phm.EXPECT().Get(gomock.Any()) - serv.handlePacket(packet) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("accepting connections", func() { - It("returns Accept when closed", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(ErrServerClosed)) - Expect(err).To(MatchError(net.ErrClosed)) - close(done) - }() - - serv.Close() - Eventually(done).Should(BeClosed()) - }) - - It("returns immediately, if an error occurred before", func() { - serv.Close() - for i := 0; i < 3; i++ { - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(ErrServerClosed)) - } - }) - - PIt("closes connection that are still handshaking after Close", func() { - serv.Close() - - destroyed := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - conf *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - conn.EXPECT().run().MaxTimes(1) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.handleInitialImpl( - receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, - ) - Eventually(destroyed).Should(BeClosed()) - }) - - It("returns when the context is canceled", func() { - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(ctx) - Expect(err).To(MatchError("context canceled")) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - cancel() - Eventually(done).Should(BeClosed()) - }) - - It("uses the config returned by GetConfigClient", func() { - conn := NewMockQUICConn(mockCtrl) - - conf := &Config{MaxIncomingStreams: 1234} - serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(conn)) - close(done) - }() - - handshakeChan := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - conf *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().HandshakeComplete().Return(handshakeChan) - conn.EXPECT().run() - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.handleInitialImpl( - receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, - ) - Consistently(done).ShouldNot(BeClosed()) - close(handshakeChan) // complete the handshake - Eventually(done).Should(BeClosed()) - }) - - It("rejects a connection attempt when GetConfigClient returns an error", func() { - serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) - - phm.EXPECT().Get(gomock.Any()) - done := make(chan struct{}) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - rejectHdr := parseHeader(b) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - return len(b), nil - }) - serv.handleInitialImpl( - receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, - ) - Eventually(done).Should(BeClosed()) - }) - - It("accepts new connections when the handshake completes", func() { - conn := NewMockQUICConn(mockCtrl) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(conn)) - close(done) - }() - - handshakeChan := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().HandshakeComplete().Return(handshakeChan) - conn.EXPECT().run() - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.handleInitialImpl( - receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, - ) - Consistently(done).ShouldNot(BeClosed()) - close(handshakeChan) // complete the handshake - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("server accepting connections that haven't completed the handshake", func() { - var ( - serv *EarlyListener - phm *MockPacketHandlerManager - ) - - BeforeEach(func() { - var err error - serv, err = ListenEarly(conn, tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - phm = NewMockPacketHandlerManager(mockCtrl) - serv.baseServer.connHandler = phm - }) - - AfterEach(func() { - serv.Close() - }) - - It("accepts new connections when they become ready", func() { - conn := NewMockQUICConn(mockCtrl) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(conn)) - close(done) - }() - - ready := make(chan struct{}) - serv.baseServer.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run() - conn.EXPECT().earlyConnReady().Return(ready) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.baseServer.handleInitialImpl( - receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, - ) - Consistently(done).ShouldNot(BeClosed()) - close(ready) - Eventually(done).Should(BeClosed()) - }) - - It("rejects new connection attempts if the accept queue is full", func() { - connChan := make(chan *MockQUICConn, 1) - var wg sync.WaitGroup // to make sure the test fully completes - wg.Add(protocol.MaxAcceptQueueSize) - serv.baseServer.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - ready := make(chan struct{}) - close(ready) - conn := <-connChan - conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().run().Do(func() error { wg.Done(); return nil }) - conn.EXPECT().earlyConnReady().Return(ready) - conn.EXPECT().Context().Return(context.Background()) - return conn - } - - phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) - for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - conn := NewMockQUICConn(mockCtrl) - connChan <- conn - serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) - } - - Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize)) - wg.Wait() - wg.Add(1) - - rejected := make(chan struct{}) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(qerr.TransportErrorCode) { - close(rejected) - }) - connChan <- conn - serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) - Eventually(rejected).Should(BeClosed()) - }) - - It("doesn't accept new connections if they were closed in the mean time", func() { - p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) - ctx, cancel := context.WithCancel(context.Background()) - connCreated := make(chan struct{}) - conn := NewMockQUICConn(mockCtrl) - serv.baseServer.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - runner connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn.EXPECT().handlePacket(p) - conn.EXPECT().run() - conn.EXPECT().earlyConnReady() - conn.EXPECT().Context().Return(ctx) - close(connCreated) - return conn - } - - phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.baseServer.handlePacket(p) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(connCreated).Should(BeClosed()) - cancel() - time.Sleep(scaleDuration(200 * time.Millisecond)) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.Accept(context.Background()) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - - // make the go routine return - Expect(serv.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - }) - - Context("0-RTT", func() { - var ( - tr *Transport - serv *baseServer - phm *MockPacketHandlerManager - tracer *mocklogging.MockTracer - ) - - BeforeEach(func() { - var t *logging.Tracer - t, tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: t} - ln, err := tr.ListenEarly(tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - phm = NewMockPacketHandlerManager(mockCtrl) - serv = ln.baseServer - serv.connHandler = phm - }) - - AfterEach(func() { - tracer.EXPECT().Close() - Expect(tr.Close()).To(Succeed()) - }) - - It("passes packets to existing connections", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 100)) - conn := NewMockPacketHandler(mockCtrl) - phm.EXPECT().Get(connID).Return(conn, true) - handled := make(chan struct{}) - conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) - serv.handlePacket(p) - Eventually(handled).Should(BeClosed()) - }) - - It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - - var zeroRTTPackets []receivedPacket - - for i := 0; i < protocol.Max0RTTQueueLen; i++ { - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 100+i)) - phm.EXPECT().Get(connID) - serv.handlePacket(p) - zeroRTTPackets = append(zeroRTTPackets, p) - } - - // send one more packet, this one should be dropped - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 200)) - phm.EXPECT().Get(connID) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) - serv.handlePacket(p) - - initial := getPacket(&wire.Header{ + packet := getLongHeaderPacket(t, conn.LocalAddr(), + &wire.ExtendedHeader{ + Header: wire.Header{ Type: protocol.PacketTypeInitial, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - called := make(chan struct{}) - serv.newConn = func( - _ context.Context, - _ context.CancelCauseFunc, - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ ConnectionIDGenerator, - _ *statelessResetter, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ *logging.ConnectionTracer, - _ utils.Logger, - _ protocol.Version, - ) quicConn { - conn := NewMockQUICConn(mockCtrl) - var calls []any - calls = append(calls, conn.EXPECT().handlePacket(initial)) - for _, p := range zeroRTTPackets { - calls = append(calls, conn.EXPECT().handlePacket(p)) - } - gomock.InOrder(calls...) - conn.EXPECT().run() - conn.EXPECT().earlyConnReady() - conn.EXPECT().Context().Return(context.Background()) - close(called) - // shutdown - conn.EXPECT().closeWithTransportError(gomock.Any()) - return conn - } + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, protocol.MinUnknownVersionPacketSize), + ) + argsChan := make(chan sentPacketCallArgs, 1) + mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do( + func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { + argsChan <- sentPacketCallArgs{hdr: hdr} + }, + ) + server.handlePacket(packet) + checkRetry(t, conn, argsChan, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})) +} - phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - serv.handlePacket(initial) - Eventually(called).Should(BeClosed()) +func TestServerTokenValidation(t *testing.T) { + var tokenGeneratorKey handshake.TokenProtectorKey + rand.Read(tokenGeneratorKey[:]) + tg := handshake.NewTokenGenerator(tokenGeneratorKey) + + t.Run("retry token with invalid address", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + token, err := tg.NewRetryToken( + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, + protocol.ConnectionID{}, + protocol.ConnectionID{}, + ) + require.NoError(t, err) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + useRetry: true, + tracer: tracer, + tokenGeneratorKey: tokenGeneratorKey, }) - It("limits the number of queues", func() { - for i := 0; i < protocol.Max0RTTQueues; i++ { - b := make([]byte, 16) - rand.Read(b) - connID := protocol.ParseConnectionID(b) - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 100+i)) - phm.EXPECT().Get(connID) - serv.handlePacket(p) - } - - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 200)) - phm.EXPECT().Get(connID) - dropped := make(chan struct{}) - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(dropped) - }) - serv.handlePacket(p) - Eventually(dropped).Should(BeClosed()) - }) - - It("drops queues after a while", func() { - now := time.Now() - - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - p := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID, - Version: serv.config.Versions[0], - }, make([]byte, 200)) - p.rcvTime = now - - connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) - p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) - p2 := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID2, - Version: serv.config.Versions[0], - }, make([]byte, 300)) - p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet - - dropped1 := make(chan struct{}) - dropped2 := make(chan struct{}) - // need to register the call before handling the packet to avoid race condition - gomock.InOrder( - tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(dropped1) - }), - tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(dropped2) - }), - ) - - phm.EXPECT().Get(connID) - serv.handlePacket(p) - - // There's no cleanup Go routine. - // Cleanup is triggered when new packets are received. - - phm.EXPECT().Get(connID2) - serv.handlePacket(p2) - // make sure no cleanup is executed - Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) - - // There's no cleanup Go routine. - // Cleanup is triggered when new packets are received. - connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) - p3 := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID3, - Version: serv.config.Versions[0], - }, make([]byte, 200)) - p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup - phm.EXPECT().Get(connID3) - serv.handlePacket(p3) - Eventually(dropped1).Should(BeClosed()) - Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) - - // make sure the second packet is also cleaned up - connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) - p4 := getPacket(&wire.Header{ - Type: protocol.PacketType0RTT, - DestConnectionID: connID4, - Version: serv.config.Versions[0], - }, make([]byte, 200)) - p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup - phm.EXPECT().Get(connID4) - serv.handlePacket(p4) - Eventually(dropped2).Should(BeClosed()) - }) + testServerTokenValidation(t, server, mockTracer, newUDPConnLocalhost(t), token, false, true, false) }) -}) + + t.Run("expired retry token", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + conn := newUDPConnLocalhost(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + useRetry: true, + tracer: tracer, + config: &Config{HandshakeIdleTimeout: time.Millisecond / 2}, + tokenGeneratorKey: tokenGeneratorKey, + }) + + token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{}) + require.NoError(t, err) + // the maximum retry token age is equivalent to the handshake timeout + time.Sleep(time.Millisecond) // make sure the token is expired + testServerTokenValidation(t, server, mockTracer, conn, token, false, true, false) + }) + + // if the packet is corrupted, it will just be dropped (no INVALID_TOKEN nor Retry is sent) + t.Run("corrupted packet", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + useRetry: true, + tracer: tracer, + config: &Config{HandshakeIdleTimeout: time.Millisecond / 2}, + tokenGeneratorKey: tokenGeneratorKey, + }) + + conn := newUDPConnLocalhost(t) + token, err := tg.NewRetryToken(conn.LocalAddr(), protocol.ConnectionID{}, protocol.ConnectionID{}) + require.NoError(t, err) + time.Sleep(time.Millisecond) // make sure the token is expired + testServerTokenValidation(t, server, mockTracer, conn, token, true, false, true) + }) + + t.Run("invalid non-retry token", func(t *testing.T) { + var tokenGeneratorKey2 handshake.TokenProtectorKey + rand.Read(tokenGeneratorKey2[:]) + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + tokenGeneratorKey: tokenGeneratorKey2, // use a different key + useRetry: true, + tracer: tracer, + maxTokenAge: time.Millisecond, + }) + + conn := newUDPConnLocalhost(t) + token, err := tg.NewToken(conn.LocalAddr()) + require.NoError(t, err) + time.Sleep(3 * time.Millisecond) // make sure the token is expired + testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true) + }) + + t.Run("expired non-retry token", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + tokenGeneratorKey: tokenGeneratorKey, + useRetry: true, + tracer: tracer, + maxTokenAge: time.Millisecond, + }) + + conn := newUDPConnLocalhost(t) + token, err := tg.NewToken(conn.LocalAddr()) + require.NoError(t, err) + time.Sleep(3 * time.Millisecond) // make sure the token is expired + testServerTokenValidation(t, server, mockTracer, conn, token, false, false, true) + }) +} + +func testServerTokenValidation( + t *testing.T, + server *testServer, + mockTracer *mocklogging.MockTracer, + conn *net.UDPConn, + token []byte, + corruptedPacket bool, + expectInvalidTokenConnectionClose bool, + expectRetry bool, +) { + hdr := wire.Header{ + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + Token: token, + Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16, + Version: protocol.Version1, + } + packet := getLongHeaderPacketEncrypted(t, + conn.LocalAddr(), + &wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4}, + make([]byte, protocol.MinInitialPacketSize), + ) + if corruptedPacket { + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + done := make(chan struct{}) + mockTracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do( + func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }, + ) + server.handlePacket(packet) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } + return + } + + argsChan := make(chan sentPacketCallArgs, 1) + if expectInvalidTokenConnectionClose || expectRetry { + mockTracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames} + }, + ) + } + + server.handlePacket(packet) + if expectInvalidTokenConnectionClose { + checkConnectionClose(t, conn, argsChan, hdr.DestConnectionID, hdr.SrcConnectionID, qerr.InvalidToken) + } + if expectRetry { + checkRetry(t, conn, argsChan, hdr.SrcConnectionID) + } +} + +type connConstructorArgs struct { + ctx context.Context + config *Config + origDestConnID protocol.ConnectionID + retrySrcConnID *protocol.ConnectionID + clientDestConnID protocol.ConnectionID + destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID +} + +type connConstructorRecorder struct { + ch chan connConstructorArgs + + conns []quicConn +} + +func newConnConstructorRecorder(conns ...quicConn) *connConstructorRecorder { + return &connConstructorRecorder{ + ch: make(chan connConstructorArgs, len(conns)), + conns: conns, + } +} + +func (r *connConstructorRecorder) Args() <-chan connConstructorArgs { return r.ch } + +func (r *connConstructorRecorder) NewConn( + ctx context.Context, + _ context.CancelCauseFunc, + _ sendConn, + _ connRunner, + origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + _ ConnectionIDGenerator, + _ *statelessResetter, + config *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ *logging.ConnectionTracer, + _ utils.Logger, + _ protocol.Version, +) quicConn { + r.ch <- connConstructorArgs{ + ctx: ctx, + config: config, + origDestConnID: origDestConnID, + retrySrcConnID: retrySrcConnID, + clientDestConnID: clientDestConnID, + destConnID: destConnID, + srcConnID: srcConnID, + } + c := r.conns[0] + r.conns = r.conns[1:] + return c +} + +func TestServerCreateConnection(t *testing.T) { + t.Run("without retry", func(t *testing.T) { + testServerCreateConnection(t, false) + }) + t.Run("with retry", func(t *testing.T) { + testServerCreateConnection(t, true) + }) +} + +func testServerCreateConnection(t *testing.T, useRetry bool) { + mockCtrl := gomock.NewController(t) + tokenGeneratorKey := TokenGeneratorKey{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + tg := handshake.NewTokenGenerator(tokenGeneratorKey) + + server := newTestServer(t, &serverOpts{ + useRetry: useRetry, + tokenGeneratorKey: tokenGeneratorKey, + }) + + c := NewMockQUICConn(mockCtrl) + c.EXPECT().run() + c.EXPECT().Context().Return(context.Background()) + c.EXPECT().HandshakeComplete().Return(make(chan struct{})) + recorder := newConnConstructorRecorder(c) + server.newConn = recorder.NewConn + + conn := newUDPConnLocalhost(t) + var token []byte + if useRetry { + var err error + token, err = tg.NewRetryToken( + conn.LocalAddr(), + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + ) + require.NoError(t, err) + } + hdr := wire.Header{ + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + Length: protocol.MinInitialPacketSize + protocol.ByteCount(protocol.PacketNumberLen4) + 16, + Token: token, + Version: protocol.Version1, + } + packet := getLongHeaderPacketEncrypted(t, + conn.LocalAddr(), + &wire.ExtendedHeader{Header: hdr, PacketNumberLen: protocol.PacketNumberLen4}, + make([]byte, protocol.MinInitialPacketSize), + ) + c.EXPECT().handlePacket(packet) + + server.handlePacket(packet) + + var args connConstructorArgs + select { + case args = <-recorder.Args(): + case <-time.After(time.Second): + t.Fatal("timeout") + } + + assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) + assert.NotEqual(t, args.origDestConnID, args.srcConnID) + if useRetry { + assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) + assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), args.origDestConnID) + assert.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), *args.retrySrcConnID) + } else { + assert.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), args.origDestConnID) + assert.Zero(t, args.retrySrcConnID) + } + + // shutdown + c.EXPECT().closeWithTransportError(ConnectionRefused) + c.EXPECT().destroy(gomock.Any()).AnyTimes() +} + +func TestServerClose(t *testing.T) { + mockCtrl := gomock.NewController(t) + var conns []quicConn + const numConns = 3 + done := make(chan struct{}, numConns) + for range numConns { + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) // doesn't complete handshake + conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { done <- struct{}{} }) + conns = append(conns, conn) + } + recorder := newConnConstructorRecorder(conns...) + server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) + + for range numConns { + b := make([]byte, 10) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + server.handlePacket(getValidInitialPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + randConnID(6), + connID, + )) + select { + case <-recorder.Args(): + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + server.Close() + // closing closes all handshakeing connections with CONNECTION_REFUSED + for range numConns { + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + // Accept returns ErrServerClosed after closing + for range 5 { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := server.Accept(ctx) + require.ErrorIs(t, err, ErrServerClosed) + require.ErrorIs(t, err, net.ErrClosed) + } +} + +func TestServerGetConfigForClientAccept(t *testing.T) { + mockCtrl := gomock.NewController(t) + c := NewMockQUICConn(mockCtrl) + c.EXPECT().run() + c.EXPECT().Context().Return(context.Background()) + c.EXPECT().HandshakeComplete().Return(make(chan struct{})) + recorder := newConnConstructorRecorder(c) + server := newTestServer(t, &serverOpts{ + config: &Config{ + GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { + return &Config{MaxIncomingStreams: 1234}, nil + }, + }, + newConn: recorder.NewConn, + }) + + conn := newUDPConnLocalhost(t) + packet := getValidInitialPacket(t, + conn.LocalAddr(), + protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + ) + c.EXPECT().handlePacket(packet) + + server.handlePacket(packet) + + var args connConstructorArgs + select { + case args = <-recorder.Args(): + require.EqualValues(t, 1234, args.config.MaxIncomingStreams) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + assert.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), args.destConnID) + assert.NotEqual(t, args.origDestConnID, args.srcConnID) + + // shutdown + c.EXPECT().closeWithTransportError(ConnectionRefused) + c.EXPECT().destroy(gomock.Any()).AnyTimes() +} + +func TestServerGetConfigForClientReject(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + tracer: tracer, + config: &Config{ + GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { + return nil, errors.New("rejected") + }, + }, + }) + + conn := newUDPConnLocalhost(t) + srcConnID := randConnID(6) + destConnID := randConnID(8) + p := getValidInitialPacket(t, conn.LocalAddr(), srcConnID, destConnID) + server.handlePacket(p) + argsChan := make(chan sentPacketCallArgs, 1) + mockTracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(_ net.Addr, hdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + argsChan <- sentPacketCallArgs{hdr: hdr, frames: frames} + }, + ) + + checkConnectionClose(t, conn, argsChan, destConnID, srcConnID, qerr.ConnectionRefused) +} + +func TestServerPacketHandling(t *testing.T) { + mockCtrl := gomock.NewController(t) + server := newTestServer(t, &serverOpts{}) + + srcConnID := randConnID(6) + destConnID := randConnID(8) + conn := NewMockQUICConn(mockCtrl) + handledPacket := make(chan receivedPacket, 1) + conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { + handledPacket <- p + }) + server.connHandler.Add(destConnID, conn) + + server.handlePacket( + getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, srcConnID, destConnID), + ) + select { + case p := <-handledPacket: + require.Equal(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, p.remoteAddr) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // shutdown + conn.EXPECT().destroy(gomock.Any()).AnyTimes() +} + +func TestServerReceiveQueue(t *testing.T) { + mockCtrl := gomock.NewController(t) + acceptConn := make(chan struct{}) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + tracer: tracer, + newConn: func( + _ context.Context, + _ context.CancelCauseFunc, + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ *statelessResetter, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ *logging.ConnectionTracer, + _ utils.Logger, + _ protocol.Version, + ) quicConn { + <-acceptConn + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) + conn.EXPECT().run().MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) + conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) + return conn + }, + }) + + conn := newUDPConnLocalhost(t) + for range protocol.MaxServerUnprocessedPackets + 1 { + server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8))) + } + + done := make(chan struct{}) + mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeNotDetermined, gomock.Any(), logging.PacketDropDOSPrevention).Do( + func(_ net.Addr, _ logging.PacketType, _ logging.ByteCount, _ logging.PacketDropReason) { + close(done) + }, + ) + server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8))) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } + close(acceptConn) +} + +func TestServerAccept(t *testing.T) { + t.Run("without accept early", func(t *testing.T) { + testServerAccept(t, false) + }) + t.Run("with accept early", func(t *testing.T) { + testServerAccept(t, true) + }) +} + +func testServerAccept(t *testing.T, acceptEarly bool) { + mockCtrl := gomock.NewController(t) + ready := make(chan struct{}) + c := NewMockQUICConn(mockCtrl) + c.EXPECT().run() + c.EXPECT().handlePacket(gomock.Any()) + c.EXPECT().Context().Return(context.Background()) + if acceptEarly { + c.EXPECT().earlyConnReady().Return(ready) + } else { + c.EXPECT().HandshakeComplete().Return(ready) + } + recorder := newConnConstructorRecorder(c) + tracer, _ := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + acceptEarly: acceptEarly, + tracer: tracer, + newConn: recorder.NewConn, + }) + + // Accept should respect the context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := server.Accept(ctx) + require.ErrorIs(t, err, context.Canceled) + + // establish a new connection, which then starts handshaking + server.handlePacket(getValidInitialPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + randConnID(6), + randConnID(8), + )) + + accepted := make(chan error, 1) + go func() { + _, err := server.Accept(context.Background()) + accepted <- err + }() + + select { + case <-accepted: + t.Fatal("server accepted the connection too early") + case <-time.After(scaleDuration(5 * time.Millisecond)): + } + // now complete the handshake + close(ready) + + select { + case err := <-accepted: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestServerAcceptHandshakeFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := NewMockQUICConn(gomock.NewController(t)) + c.EXPECT().run() + c.EXPECT().handlePacket(gomock.Any()) + c.EXPECT().Context().Return(ctx) + c.EXPECT().HandshakeComplete().Return(make(chan struct{})) + recorder := newConnConstructorRecorder(c) + server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) + + // establish a new connection, which then starts handshaking + server.handlePacket(getValidInitialPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + randConnID(6), + randConnID(8), + )) + + accepted := make(chan error, 1) + go func() { + _, err := server.Accept(context.Background()) + accepted <- err + }() + + cancel() + select { + case <-accepted: + t.Fatal("server should not have accepted the connection") + case <-time.After(scaleDuration(5 * time.Millisecond)): + } +} + +func TestServerAcceptQueue(t *testing.T) { + mockCtrl := gomock.NewController(t) + var conns []quicConn + var rejectedConn *MockQUICConn + for i := range protocol.MaxAcceptQueueSize + 2 { + conn := NewMockQUICConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + c := make(chan struct{}) + close(c) + conn.EXPECT().HandshakeComplete().Return(c) + conn.EXPECT().Context().Return(context.Background()) + conns = append(conns, conn) + if i == protocol.MaxAcceptQueueSize { + rejectedConn = conn + continue + } + defer func(conn *MockQUICConn) { + conn.EXPECT().closeWithTransportError(ConnectionRefused).MaxTimes(1) + }(conn) + } + recorder := newConnConstructorRecorder(conns...) + server := newTestServer(t, &serverOpts{newConn: recorder.NewConn}) + + for range protocol.MaxAcceptQueueSize { + b := make([]byte, 16) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + server.handlePacket( + getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), connID), + ) + select { + case args := <-recorder.Args(): + require.Equal(t, connID, args.origDestConnID) + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + // wait for the connection to be enqueued + time.Sleep(scaleDuration(10 * time.Millisecond)) + + done := make(chan struct{}) + rejectedConn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { + close(done) + }) + server.handlePacket( + getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(6), randConnID(8)), + ) + select { + case <-recorder.Args(): + case <-time.After(time.Second): + t.Fatal("timeout") + } + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // accept one connection, freeing up one slot in the accept queue + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := server.Accept(ctx) + require.NoError(t, err) + + // it's now possible to enqueue a new connection + server.handlePacket( + getValidInitialPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + randConnID(6), + protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + ), + ) + select { + case args := <-recorder.Args(): + require.Equal(t, protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), args.origDestConnID) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestServer0RTTReordering(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + conn := NewMockQUICConn(mockCtrl) + recorder := newConnConstructorRecorder(conn) + server := newTestServer(t, &serverOpts{ + acceptEarly: true, + tracer: tracer, + newConn: recorder.NewConn, + }) + + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + + var zeroRTTPackets []receivedPacket + + for range protocol.Max0RTTQueueLen { + p := getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketType0RTT, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: connID, + Length: 100, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, 100), + ) + server.handlePacket(p) + zeroRTTPackets = append(zeroRTTPackets, p) + } + + // send one more packet, this one should be dropped + p := getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketType0RTT, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: connID, + Length: 100, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, 100), + ) + mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + server.handlePacket(p) + + // now receive the Initial + initial := getValidInitialPacket(t, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, randConnID(5), connID) + packets := make(chan receivedPacket, protocol.Max0RTTQueueLen+1) + conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { packets <- p }).AnyTimes() + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().earlyConnReady().Return(make(chan struct{})) + conn.EXPECT().run() + server.handlePacket(initial) + + for i := range protocol.Max0RTTQueueLen + 1 { + select { + case p := <-packets: + if i == 0 { + require.Equal(t, initial.data, p.data) + } else { + require.Equal(t, zeroRTTPackets[i-1], p) + } + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + // shutdown + conn.EXPECT().closeWithTransportError(gomock.Any()).AnyTimes() +} + +func TestServer0RTTQueueing(t *testing.T) { + mockCtrl := gomock.NewController(t) + tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + server := newTestServer(t, &serverOpts{ + acceptEarly: true, + tracer: tracer, + }) + + firstRcvTime := time.Now() + otherRcvTime := firstRcvTime.Add(protocol.Max0RTTQueueingDuration / 2) + var sizes []protocol.ByteCount + for i := range protocol.Max0RTTQueues { + b := make([]byte, 16) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + size := protocol.ByteCount(500 + i) + p := getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketType0RTT, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: connID, + Length: size, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, size), + ) + if i == 0 { + p.rcvTime = firstRcvTime + } else { + p.rcvTime = otherRcvTime + } + sizes = append(sizes, p.Size()) + server.handlePacket(p) + } + + // maximum number of 0-RTT queues is reached, further packets are dropped + p := getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketType0RTT, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + Length: 123, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, 123), + ) + dropped := make(chan struct{}, protocol.Max0RTTQueues) + mockTracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do( + func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} }, + ) + server.handlePacket(p) + select { + case <-dropped: + case <-time.After(time.Second): + t.Fatal("timeout") + } + require.True(t, mockCtrl.Satisfied()) + + // There's no cleanup Go routine. + // Cleanup is triggered when new packets are received. + // 1. Receive one handshake packet, which triggers the cleanup of the first 0-RTT queue + mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any()) + mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, sizes[0], gomock.Any()).Do( + func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} }, + ) + triggerPacket := getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeHandshake, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + Length: 123, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, 123), + ) + triggerPacket.rcvTime = firstRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) + server.handlePacket(triggerPacket) + select { + case <-dropped: + case <-time.After(time.Second): + t.Fatal("timeout") + } + require.True(t, mockCtrl.Satisfied()) + + // 2. Receive another handshake packet, which triggers the cleanup of the other 0-RTT queues + triggerPacket = getLongHeaderPacket(t, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}, + &wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeHandshake, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + Length: 124, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen4, + }, + make([]byte, 124), + ) + triggerPacket.rcvTime = otherRcvTime.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) + mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketTypeHandshake, gomock.Any(), gomock.Any()) + for i := range sizes[1:] { + mockTracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, sizes[i+1], gomock.Any()).Do( + func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { dropped <- struct{}{} }, + ) + } + server.handlePacket(triggerPacket) + + for range protocol.Max0RTTQueues - 1 { + select { + case <-dropped: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } +}