From b493c5d827cda015e603c692168e6acf4103512c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 23 Dec 2024 14:47:44 +0800 Subject: [PATCH] migrate the transport tests away from Ginkgo (#4783) * migrate the transport tests away from Ginkgo * simplify mock net.PacketConn implementation --- quic_suite_test.go | 27 ++ stream_test.go | 13 - transport.go | 4 +- transport_test.go | 792 ++++++++++++++++++++++----------------------- 4 files changed, 410 insertions(+), 426 deletions(-) diff --git a/quic_suite_test.go b/quic_suite_test.go index 954ca60b9..1b4e3b67f 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -4,13 +4,18 @@ import ( "bytes" "io" "log" + "net" + "os" "runtime/pprof" + "strconv" "strings" "sync" "testing" + "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -32,6 +37,28 @@ var _ = BeforeSuite(func() { log.SetOutput(io.Discard) }) +// in the tests for the stream deadlines we set a deadline +// and wait to make an assertion when Read / Write was unblocked +// on the CIs, the timing is a lot less precise, so scale every duration by this factor +func scaleDuration(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + if scaleFactor == 0 { + panic("TIMESCALE_FACTOR is 0") + } + return time.Duration(scaleFactor) * t +} + +func newUPDConnLocalhost(t testing.TB) *net.UDPConn { + t.Helper() + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + return conn +} + func areServersRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) diff --git a/stream_test.go b/stream_test.go index cd7475403..f01293938 100644 --- a/stream_test.go +++ b/stream_test.go @@ -5,7 +5,6 @@ import ( "errors" "io" "os" - "strconv" "time" "github.com/quic-go/quic-go/internal/mocks" @@ -18,18 +17,6 @@ import ( "go.uber.org/mock/gomock" ) -// in the tests for the stream deadlines we set a deadline -// and wait to make an assertion when Read / Write was unblocked -// on the CIs, the timing is a lot less precise, so scale every duration by this factor -func scaleDuration(t time.Duration) time.Duration { - scaleFactor := 1 - if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set - scaleFactor = f - } - Expect(scaleFactor).ToNot(BeZero()) - return time.Duration(scaleFactor) * t -} - var _ = Describe("Stream", func() { const streamID protocol.StreamID = 1337 diff --git a/transport.go b/transport.go index d835ea00d..f9518beba 100644 --- a/transport.go +++ b/transport.go @@ -241,7 +241,9 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn - t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + if t.handlerMap == nil { // allows mocking the handlerMap in tests + t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + } t.listening = make(chan struct{}) t.closeQueue = make(chan closePacket, 4) diff --git a/transport_test.go b/transport_test.go index 9a7718c88..33a60b325 100644 --- a/transport_test.go +++ b/transport_test.go @@ -3,441 +3,409 @@ package quic import ( "bytes" "context" - "crypto/rand" "crypto/tls" "errors" "net" + "os" "syscall" + "testing" "time" mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -var _ = Describe("Transport", func() { - type packetToRead struct { - addr net.Addr - data []byte - err error +type mockPacketConn struct { + localAddr net.Addr + readErrs chan error +} + +func (c *mockPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + err, ok := <-c.readErrs + if !ok { + return 0, nil, net.ErrClosed } + return 0, nil, err +} - getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { - b, err := (&wire.ExtendedHeader{ - Header: wire.Header{ - Type: t, - DestConnectionID: connID, - Length: length, - Version: protocol.Version1, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }).Append(nil, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - return b - } +func (c *mockPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { panic("implement me") } +func (c *mockPacketConn) LocalAddr() net.Addr { return c.localAddr } +func (c *mockPacketConn) Close() error { close(c.readErrs); return nil } +func (c *mockPacketConn) SetDeadline(t time.Time) error { return nil } +func (c *mockPacketConn) SetReadDeadline(t time.Time) error { return nil } +func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil } - getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) - } +type mockPacketHandler struct { + packets chan<- receivedPacket + destruction chan<- error +} - newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { - p, ok := <-packetChan - if !ok { - return 0, nil, errors.New("closed") - } - return copy(b, p.data), p.addr, p.err - }).AnyTimes() - // for shutdown - conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() - return conn - } +func (h *mockPacketHandler) handlePacket(p receivedPacket) { h.packets <- p } +func (h *mockPacketHandler) destroy(err error) { h.destruction <- err } +func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {} - It("handles packets for different packet handlers on the same packet conn", func() { - packetChan := make(chan packetToRead) - tr := &Transport{Conn: newMockPacketConn(packetChan)} - tr.init(true) - phm := NewMockPacketHandlerManager(mockCtrl) - tr.handlerMap = phm - connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) +func getPacket(t *testing.T, connID protocol.ConnectionID) []byte { + return getPacketWithPacketType(t, connID, protocol.PacketTypeHandshake, 2) +} - handled := make(chan struct{}, 2) - phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { - h := NewMockPacketHandler(mockCtrl) - h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { - defer GinkgoRecover() - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID1)) - handled <- struct{}{} - }) - return h, true - }) - phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { - h := NewMockPacketHandler(mockCtrl) - h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { - defer GinkgoRecover() - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID2)) - handled <- struct{}{} - }) - return h, true - }) - - packetChan <- packetToRead{data: getPacket(connID1)} - packetChan <- packetToRead{data: getPacket(connID2)} - - Eventually(handled).Should(Receive()) - Eventually(handled).Should(Receive()) - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("closes listeners", func() { - packetChan := make(chan packetToRead) - tr := &Transport{Conn: newMockPacketConn(packetChan)} - defer tr.Close() - ln, err := tr.Listen(&tls.Config{}, nil) - Expect(err).ToNot(HaveOccurred()) - phm := NewMockPacketHandlerManager(mockCtrl) - tr.handlerMap = phm - - Expect(ln.Close()).To(Succeed()) - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("closes transport concurrently with listener", func() { - // try 10 times to trigger race conditions - for i := 0; i < 10; i++ { - packetChan := make(chan packetToRead) - tr := &Transport{Conn: newMockPacketConn(packetChan)} - ln, err := tr.Listen(&tls.Config{}, nil) - Expect(err).ToNot(HaveOccurred()) - ch := make(chan bool) - // Close transport and listener concurrently. - go func() { - ch <- true - Expect(ln.Close()).To(Succeed()) - ch <- true - }() - <-ch - close(packetChan) - Expect(tr.Close()).To(Succeed()) - <-ch - } - }) - - It("drops unparseable QUIC packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - packetChan := make(chan packetToRead) - t, tracer := mocklogging.NewMockTracer(mockCtrl) - tr := &Transport{ - Conn: newMockPacketConn(packetChan), - ConnectionIDLength: 10, - Tracer: t, - } - tr.init(true) - dropped := make(chan struct{}) - tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) - packetChan <- packetToRead{ - addr: addr, - data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, - } - Eventually(dropped).Should(BeClosed()) - - // shutdown - tracer.EXPECT().Close() - close(packetChan) - tr.Close() - }) - - It("closes when reading from the conn fails", func() { - packetChan := make(chan packetToRead) - tr := Transport{Conn: newMockPacketConn(packetChan)} - defer tr.Close() - phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(true) - tr.handlerMap = phm - - done := make(chan struct{}) - phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) - packetChan <- packetToRead{err: errors.New("read failed")} - Eventually(done).Should(BeClosed()) - - // shutdown - close(packetChan) - tr.Close() - }) - - It("continues listening after temporary errors", func() { - packetChan := make(chan packetToRead) - tr := Transport{Conn: newMockPacketConn(packetChan)} - defer tr.Close() - phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(true) - tr.handlerMap = phm - - tempErr := deadlineError{} - Expect(tempErr.Temporary()).To(BeTrue()) - packetChan <- packetToRead{err: tempErr} - // don't expect any calls to phm.Close - time.Sleep(50 * time.Millisecond) - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("handles short header packets resets", func() { - connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) - packetChan := make(chan packetToRead) - tr := Transport{ - Conn: newMockPacketConn(packetChan), - ConnectionIDLength: connID.Len(), - } - tr.init(true) - defer tr.Close() - phm := NewMockPacketHandlerManager(mockCtrl) - tr.handlerMap = phm - - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var b []byte - b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) - Expect(err).ToNot(HaveOccurred()) - b = append(b, token[:]...) - conn := NewMockPacketHandler(mockCtrl) - gomock.InOrder( - phm.EXPECT().Get(connID).Return(conn, true), - conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { - Expect(p.data).To(Equal(b)) - Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) - }), - ) - packetChan <- packetToRead{data: b} - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("handles stateless resets", func() { - connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) - packetChan := make(chan packetToRead) - tr := Transport{ - Conn: newMockPacketConn(packetChan), - ConnectionIDLength: connID.Len(), - } - tr.init(true) - defer tr.Close() - phm := NewMockPacketHandlerManager(mockCtrl) - tr.handlerMap = phm - - var token protocol.StatelessResetToken - rand.Read(token[:]) - - var b []byte - b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) - Expect(err).ToNot(HaveOccurred()) - b = append(b, token[:]...) - conn := NewMockPacketHandler(mockCtrl) - destroyed := make(chan struct{}) - gomock.InOrder( - phm.EXPECT().Get(connID), - phm.EXPECT().GetByResetToken(token).Return(conn, true), - conn.EXPECT().destroy(gomock.Any()).Do(func(err error) { - Expect(err).To(MatchError(&StatelessResetError{})) - close(destroyed) - }), - ) - packetChan <- packetToRead{data: b} - Eventually(destroyed).Should(BeClosed()) - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("sends stateless resets", func() { - connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) - packetChan := make(chan packetToRead) - conn := newMockPacketConn(packetChan) - tr := Transport{ - Conn: conn, - StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, - ConnectionIDLength: connID.Len(), - } - tr.init(true) - defer tr.Close() - phm := NewMockPacketHandlerManager(mockCtrl) - tr.handlerMap = phm - - var b []byte - b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) - Expect(err).ToNot(HaveOccurred()) - b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...) - - var token protocol.StatelessResetToken - rand.Read(token[:]) - written := make(chan struct{}) - gomock.InOrder( - phm.EXPECT().Get(connID), - phm.EXPECT().GetByResetToken(gomock.Any()), - phm.EXPECT().GetStatelessResetToken(connID).Return(token), - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) (int, error) { - defer close(written) - Expect(bytes.Contains(b, token[:])).To(BeTrue()) - return len(b), nil - }), - ) - packetChan <- packetToRead{data: b} - Eventually(written).Should(BeClosed()) - - // shutdown - phm.EXPECT().Close(gomock.Any()) - close(packetChan) - tr.Close() - }) - - It("closes uninitialized Transport and closes underlying PacketConn", func() { - packetChan := make(chan packetToRead) - pconn := newMockPacketConn(packetChan) - - tr := &Transport{ - Conn: pconn, - createdConn: true, // owns pconn - } - // NO init - - // shutdown - close(packetChan) - pconn.EXPECT().Close() - Expect(tr.Close()).To(Succeed()) - }) - - It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() { - packetChan := make(chan packetToRead) - pconn := newMockPacketConn(packetChan) - syscallconn := &mockSyscallConn{pconn} - - tr := &Transport{ - Conn: syscallconn, - } - - err := tr.init(false) - Expect(err).To(HaveOccurred()) - conns := getMultiplexer().(*connMultiplexer).conns - Expect(len(conns)).To(BeZero()) - }) - - It("allows receiving non-QUIC packets", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - packetChan := make(chan packetToRead) - tr := &Transport{ - Conn: newMockPacketConn(packetChan), - ConnectionIDLength: 10, - } - tr.init(true) - receivedPacketChan := make(chan []byte) - go func() { - defer GinkgoRecover() - b := make([]byte, 100) - n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) - Expect(err).ToNot(HaveOccurred()) - Expect(addr).To(Equal(remoteAddr)) - receivedPacketChan <- b[:n] - }() - // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. - // Give the Go routine some time to spin up. - time.Sleep(scaleDuration(50 * time.Millisecond)) - packetChan <- packetToRead{ - addr: remoteAddr, - data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, - } - - Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) - - // shutdown - close(packetChan) - tr.Close() - }) - - It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - packetChan := make(chan packetToRead) - t, tracer := mocklogging.NewMockTracer(mockCtrl) - tr := &Transport{ - Conn: newMockPacketConn(packetChan), - ConnectionIDLength: 10, - Tracer: t, - } - tr.init(true) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) - Expect(err).To(MatchError(context.Canceled)) - - for i := 0; i < maxQueuedNonQUICPackets; i++ { - packetChan <- packetToRead{ - addr: remoteAddr, - data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, - } - } - - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { - close(done) - }) - packetChan <- packetToRead{ - addr: remoteAddr, - data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, - } - Eventually(done).Should(BeClosed()) - - // shutdown - tracer.EXPECT().Close() - close(packetChan) - tr.Close() - }) - - remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} - DescribeTable("setting the tls.Config.ServerName", - func(expected string, conf *tls.Config, addr net.Addr, host string) { - setTLSConfigServerName(conf, addr, host) - Expect(conf.ServerName).To(Equal(expected)) +func getPacketWithPacketType(t *testing.T, connID protocol.ConnectionID, typ protocol.PacketType, length protocol.ByteCount) []byte { + t.Helper() + b, err := (&wire.ExtendedHeader{ + Header: wire.Header{ + Type: typ, + DestConnectionID: connID, + Length: length, + Version: protocol.Version1, }, - Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), - Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), - Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), - Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), + PacketNumberLen: protocol.PacketNumberLen2, + }).Append(nil, protocol.Version1) + require.NoError(t, err) + return b +} + +func TestTransportPacketHandling(t *testing.T) { + mockCtrl := gomock.NewController(t) + phm := NewMockPacketHandlerManager(mockCtrl) + + tr := &Transport{ + Conn: newUPDConnLocalhost(t), + handlerMap: phm, + } + tr.init(true) + defer func() { + phm.EXPECT().Close(gomock.Any()) + tr.Close() + }() + + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + + connChan1 := make(chan receivedPacket, 1) + conn1 := &mockPacketHandler{packets: connChan1} + phm.EXPECT().Get(connID1).Return(conn1, true) + connChan2 := make(chan receivedPacket, 1) + conn2 := &mockPacketHandler{packets: connChan2} + phm.EXPECT().Get(connID2).Return(conn2, true) + + conn := newUPDConnLocalhost(t) + _, err := conn.WriteTo(getPacket(t, connID1), tr.Conn.LocalAddr()) + require.NoError(t, err) + _, err = conn.WriteTo(getPacket(t, connID2), tr.Conn.LocalAddr()) + require.NoError(t, err) + + select { + case p := <-connChan1: + require.Equal(t, conn.LocalAddr(), p.remoteAddr) + connID, err := wire.ParseConnectionID(p.data, 0) + require.NoError(t, err) + require.Equal(t, connID1, connID) + case <-time.After(time.Second): + t.Fatal("timeout") + } + select { + case p := <-connChan2: + require.Equal(t, conn.LocalAddr(), p.remoteAddr) + connID, err := wire.ParseConnectionID(p.data, 0) + require.NoError(t, err) + require.Equal(t, connID2, connID) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestTransportAndListenerConcurrentClose(t *testing.T) { + // try 10 times to trigger race conditions + for i := 0; i < 10; i++ { + tr := &Transport{Conn: newUPDConnLocalhost(t)} + ln, err := tr.Listen(&tls.Config{}, nil) + require.NoError(t, err) + // close transport and listener concurrently + lnErrChan := make(chan error, 1) + go func() { lnErrChan <- ln.Close() }() + require.NoError(t, tr.Close()) + select { + case err := <-lnErrChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + } +} + +func TestTransportErrFromConn(t *testing.T) { + mockCtrl := gomock.NewController(t) + phm := NewMockPacketHandlerManager(mockCtrl) + readErrChan := make(chan error, 2) + conn := &mockPacketConn{readErrs: readErrChan, localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}} + tr := Transport{Conn: conn, handlerMap: phm} + defer tr.Close() + + tr.init(true) + tr.handlerMap = phm + + // temporary errors don't lead to a shutdown... + var tempErr deadlineError + require.True(t, tempErr.Temporary()) + readErrChan <- tempErr + // don't expect any calls to phm.Close + time.Sleep(scaleDuration(20 * time.Millisecond)) + + // ...but non-temporary errors do + done := make(chan struct{}) + phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) + readErrChan <- errors.New("read failed") + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // TODO(#4778): test that it's not possible to listen after the transport is closed +} + +func TestTransportStatelessResetReceiving(t *testing.T) { + mockCtrl := gomock.NewController(t) + phm := NewMockPacketHandlerManager(mockCtrl) + tr := &Transport{ + Conn: newUPDConnLocalhost(t), + ConnectionIDLength: 4, + handlerMap: phm, + } + tr.init(true) + defer func() { + phm.EXPECT().Close(gomock.Any()) + tr.Close() + }() + + // TODO(#4781): test that packets too short to be stateless resets are dropped + + connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) + // now send a packet with a connection ID that doesn't exist + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) + require.NoError(t, err) + b = append(b, token[:]...) + + destroyChan := make(chan error, 1) + conn1 := &mockPacketHandler{destruction: destroyChan} + gomock.InOrder( + phm.EXPECT().Get(connID), // no handler for this connection ID + phm.EXPECT().GetByResetToken(token).Return(conn1, true), ) -}) -type mockSyscallConn struct { - net.PacketConn + conn := newUPDConnLocalhost(t) + _, err = conn.WriteTo(b, tr.Conn.LocalAddr()) + require.NoError(t, err) + + select { + case err := <-destroyChan: + require.Error(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } } -func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) { - return nil, errors.New("mocked") +func TestTransportStatelessResetSending(t *testing.T) { + mockCtrl := gomock.NewController(t) + phm := NewMockPacketHandlerManager(mockCtrl) + tr := &Transport{ + Conn: newUPDConnLocalhost(t), + ConnectionIDLength: 4, + StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + handlerMap: phm, + } + tr.init(true) + defer func() { + phm.EXPECT().Close(gomock.Any()) + tr.Close() + }() + + connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12}) + phm.EXPECT().Get(connID).Times(2) // no handler for this connection ID + phm.EXPECT().GetByResetToken(gomock.Any()).Times(2) + + // now send a packet with a connection ID that doesn't exist + b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne) + require.NoError(t, err) + + conn := newUPDConnLocalhost(t) + + // no stateless reset sent for packets smaller than MinStatelessResetSize + _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...), tr.Conn.LocalAddr()) + require.NoError(t, err) + conn.SetReadDeadline(time.Now().Add(scaleDuration(10 * time.Millisecond))) + _, _, err = conn.ReadFrom(make([]byte, 1024)) + require.Error(t, err) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + // no stateless reset sent for packets smaller than MinStatelessResetSize + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + phm.EXPECT().GetStatelessResetToken(connID).Return(token) + _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr()) + require.NoError(t, err) + conn.SetReadDeadline(time.Now().Add(time.Second)) + p := make([]byte, 1024) + n, addr, err := conn.ReadFrom(p) + require.NoError(t, err) + require.Equal(t, addr, tr.Conn.LocalAddr()) + require.Contains(t, string(p[:n]), string(token[:])) +} + +func TestTransportDropsUnparseableQUICPackets(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockTracer, tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newUPDConnLocalhost(t), + ConnectionIDLength: 10, + Tracer: mockTracer, + } + require.NoError(t, tr.init(true)) + defer func() { + tracer.EXPECT().Close() + tr.Close() + }() + + conn := newUPDConnLocalhost(t) + + dropped := make(chan struct{}) + tracer.EXPECT().DroppedPacket(conn.LocalAddr(), logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do( + func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }, + ) + _, err := conn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr()) + require.NoError(t, err) + select { + case <-dropped: + case <-time.After(time.Second): + t.Fatal("timeout waiting for packet to be dropped") + } +} + +func TestTransportSingleListener(t *testing.T) { + tr := &Transport{Conn: newUPDConnLocalhost(t)} + require.NoError(t, tr.init(true)) + defer tr.Close() + + // TODO(#4779): test that packets are dropped if no listener is set + + ln, err := tr.Listen(&tls.Config{}, nil) + require.NoError(t, err) + + // only a single listener can be set + _, err = tr.Listen(&tls.Config{}, nil) + require.Error(t, err) + require.ErrorIs(t, err, errListenerAlreadySet) + + require.NoError(t, ln.Close()) + // now it's possible to add a new listener + ln, err = tr.Listen(&tls.Config{}, nil) + require.NoError(t, err) + defer ln.Close() +} + +func TestTransportNonQUICPackets(t *testing.T) { + tr := &Transport{Conn: newUPDConnLocalhost(t)} + defer tr.Close() + + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(5*time.Millisecond)) + defer cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024)) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + + conn := newUPDConnLocalhost(t) + data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3} + _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + + ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(time.Second)) + defer cancel() + b := make([]byte, 1024) + n, addr, err := tr.ReadNonQUICPacket(ctx, b) + require.NoError(t, err) + require.Equal(t, data, b[:n]) + require.Equal(t, addr, conn.LocalAddr()) + + // now send a lot of packets without reading them + for i := range 2 * maxQueuedNonQUICPackets { + data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...) + _, err = conn.WriteTo(data, tr.Conn.LocalAddr()) + require.NoError(t, err) + } + time.Sleep(scaleDuration(10 * time.Millisecond)) + + var received int + for { + ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) + defer cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, b) + if errors.Is(err, context.DeadlineExceeded) { + break + } + require.NoError(t, err) + received++ + } + require.Equal(t, received, maxQueuedNonQUICPackets) +} + +type faultySyscallConn struct{ net.PacketConn } + +func (c *faultySyscallConn) SyscallConn() (syscall.RawConn, error) { return nil, errors.New("mocked") } + +func TestTransportFaultySyscallConn(t *testing.T) { + syscallconn := &faultySyscallConn{PacketConn: newUPDConnLocalhost(t)} + + tr := &Transport{Conn: syscallconn} + _, err := tr.Listen(&tls.Config{}, nil) + require.Error(t, err) + require.ErrorContains(t, err, "mocked") + + conns := getMultiplexer().(*connMultiplexer).conns + require.Empty(t, conns) +} + +func TestTransportSetTLSConfigServerName(t *testing.T) { + for _, tt := range []struct { + name string + expected string + conf *tls.Config + host string + }{ + { + name: "uses the value from the config", + expected: "foo.bar", + conf: &tls.Config{ServerName: "foo.bar"}, + host: "baz.foo", + }, + { + name: "uses the hostname", + expected: "golang.org", + conf: &tls.Config{}, + host: "golang.org", + }, + { + name: "removes the port from the hostname", + expected: "golang.org", + conf: &tls.Config{}, + host: "golang.org:1234", + }, + { + name: "uses the IP", + expected: "1.3.5.7", + conf: &tls.Config{}, + host: "", + }, + } { + t.Run(tt.name, func(t *testing.T) { + setTLSConfigServerName(tt.conf, &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}, tt.host) + require.Equal(t, tt.expected, tt.conf.ServerName) + }) + } }