From ae5a8bd35ce35470b4609caf2dedc9b00dfc7cee Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 31 Mar 2023 19:35:03 +0900 Subject: [PATCH 1/8] move the QUIC version check to the config validation --- client.go | 10 ---------- client_test.go | 5 +---- config.go | 7 +++++++ server.go | 5 ----- server_test.go | 2 +- 5 files changed, 9 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 62cad5807..ad80d4f29 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "github.com/quic-go/quic-go/internal/protocol" @@ -136,15 +135,6 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon tlsConf = tlsConf.Clone() } - // check that all versions are actually supported - if config != nil { - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - } - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return nil, err diff --git a/client_test.go b/client_test.go index 23122e629..dbd03cf3c 100644 --- a/client_test.go +++ b/client_test.go @@ -366,12 +366,9 @@ var _ = Describe("Client", func() { }) It("errors when the Config contains an invalid version", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - version := protocol.VersionNumber(0x1234) _, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + Expect(err).To(MatchError("invalid QUIC version: 0x1234")) }) It("disables bidirectional streams", func() { diff --git a/config.go b/config.go index 3ead9b7a1..b513a46b3 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package quic import ( "errors" + "fmt" "net" "time" @@ -29,6 +30,12 @@ func validateConfig(config *Config) error { if config.MaxIncomingUniStreams > 1<<60 { return errors.New("invalid value for Config.MaxIncomingUniStreams") } + // check that all QUIC versions are actually supported + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return fmt.Errorf("invalid QUIC version: %s", v) + } + } return nil } diff --git a/server.go b/server.go index 3543e5464..d5bb19e6d 100644 --- a/server.go +++ b/server.go @@ -234,11 +234,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl return nil, err } config = populateServerConfig(config) - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { diff --git a/server_test.go b/server_test.go index e9f751e60..4108c6989 100644 --- a/server_test.go +++ b/server_test.go @@ -114,7 +114,7 @@ var _ = Describe("Server", func() { It("errors when the Config contains an invalid version", func() { version := protocol.VersionNumber(0x1234) _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + Expect(err).To(MatchError("invalid QUIC version: 0x1234")) }) It("fills in default values if options are not set in the Config", func() { From 8189e75be6121fdc31dc1d6085f17015e9154667 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 6 Apr 2023 18:02:51 +0800 Subject: [PATCH 2/8] implement the Transport --- client.go | 158 +++-- client_test.go | 236 ++------ integrationtests/self/cancelation_test.go | 1 + integrationtests/self/close_test.go | 2 + integrationtests/self/conn_id_test.go | 12 +- integrationtests/self/datagram_test.go | 36 +- integrationtests/self/early_data_test.go | 1 + integrationtests/self/go119_test.go | 7 +- integrationtests/self/go120_test.go | 2 +- integrationtests/self/handshake_rtt_test.go | 11 +- integrationtests/self/handshake_test.go | 8 +- integrationtests/self/http_test.go | 79 ++- integrationtests/self/multiplex_test.go | 33 +- integrationtests/self/packetization_test.go | 3 +- integrationtests/self/self_suite_test.go | 8 + integrationtests/self/stateless_reset_test.go | 28 +- integrationtests/self/stream_test.go | 3 + integrationtests/self/timeout_test.go | 1 + integrationtests/self/uni_stream_test.go | 4 + mock_multiplexer_test.go | 65 -- mock_packet_handler_manager_test.go | 53 +- mockgen.go | 3 - multiplexer.go | 69 +-- multiplexer_test.go | 69 +-- packet_handler_map.go | 266 ++------- packet_handler_map_test.go | 555 ++++++------------ quic_suite_test.go | 17 + server.go | 92 +-- server_test.go | 40 +- transport.go | 410 +++++++++++++ transport_test.go | 287 +++++++++ 31 files changed, 1309 insertions(+), 1250 deletions(-) delete mode 100644 mock_multiplexer_test.go create mode 100644 transport.go create mode 100644 transport_test.go diff --git a/client.go b/client.go index ad80d4f29..c8ea0641a 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ type client struct { use0RTT bool packetHandlers packetHandlerManager + onClose func() tlsConf *tls.Config config *Config @@ -45,32 +46,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - -// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) - if err != nil { - return nil, err - } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") - return conn, nil -} - -func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } +func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } - return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + return dl.Dial(ctx, udpAddr, tlsConf, conf) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. +// See DialEarly for details. +func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. If @@ -78,34 +105,43 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write // packets. -// The same PacketConn can be used for multiple calls to Dial and Listen. -// QUIC connection IDs are used for demultiplexing the different connections. // The tls.Config must define an application protocol (using NextProtos). -func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialContext(ctx, pconn, addr, tlsConf, config, false, false) -} - -// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. -// The same PacketConn can be used for multiple calls to Dial and Listen, -// QUIC connection IDs are used for demultiplexing the different connections. -// The tls.Config must define an application protocol (using NextProtos). -func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - return dialContext(ctx, pconn, addr, tlsConf, config, true, false) -} - -func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) +func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + dl, err := setupTransport(c, tlsConf, false) if err != nil { return nil, err } - c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn) + conn, err := dl.Dial(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + return &Transport{ + Conn: c, + createdConn: createdPacketConn, + isSingleUse: true, + }, nil +} + +func dial( + ctx context.Context, + conn net.PacketConn, + packetHandlers packetHandlerManager, + addr net.Addr, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, + createdPacketConn bool, +) (quicConn, error) { + c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn) if err != nil { return nil, err } @@ -128,7 +164,7 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo return c.conn, nil } -func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) { +func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { @@ -149,6 +185,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon sconn: newSendPconn(pconn, remoteAddr), createdPacketConn: createdPacketConn, use0RTT: use0RTT, + onClose: onClose, tlsConf: tlsConf, config: config, version: config.Versions[0], @@ -179,13 +216,18 @@ func (c *client) dial(ctx context.Context) error { c.packetHandlers.Add(c.srcConnID, c.conn) errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) go func() { - err := c.conn.run() // returns as soon as the connection is closed - - if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { - c.packetHandlers.Destroy() + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return } - errorChan <- err + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed }() // only set when we're using 0-RTT @@ -200,14 +242,12 @@ func (c *client) dial(ctx context.Context) error { c.conn.shutdown() return ctx.Err() case err := <-errorChan: - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - } return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) case <-earlyConnChan: // ready to send 0-RTT data return nil diff --git a/client_test.go b/client_test.go index dbd03cf3c..ec8b83b8b 100644 --- a/client_test.go +++ b/client_test.go @@ -18,13 +18,17 @@ import ( . "github.com/onsi/gomega" ) +type nullMultiplexer struct{} + +func (n nullMultiplexer) AddConn(indexableConn) {} +func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } + var _ = Describe("Client", func() { var ( cl *client packetConn *MockPacketConn addr net.Addr connID protocol.ConnectionID - mockMultiplexer *MockMultiplexer origMultiplexer multiplexer tlsConf *tls.Config tracer *mocklogging.MockConnectionTracer @@ -53,6 +57,7 @@ var _ = Describe("Client", func() { originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) + tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} Eventually(areConnsRunning).Should(BeFalse()) @@ -68,10 +73,9 @@ var _ = Describe("Client", func() { logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer = NewMockMultiplexer(mockCtrl) + // replace the clientMuxer. getMultiplexer will now return the nullMultiplexer origMultiplexer = connMuxer - connMuxer = mockMultiplexer + connMuxer = &nullMultiplexer{} }) AfterEach(func() { @@ -100,48 +104,14 @@ var _ = Describe("Client", func() { generateConnectionIDForInitial = origGenerateConnectionIDForInitial }) - It("resolves the address", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - remoteAddrChan := make(chan string, 1) - newClientConnection = func( - sconn sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - remoteAddrChan <- sconn.RemoteAddr().String() - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - return conn - } - _, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) - Expect(err).ToNot(HaveOccurred()) - Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) - }) - It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) newClientConnection = func( _ sendConn, - runner connRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -162,18 +132,17 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := Dial(context.Background(), packetConn, addr, tlsConf, config) + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false) Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(Succeed()) Eventually(run).Should(BeClosed()) }) It("returns early connections", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - readyChan := make(chan struct{}) done := make(chan struct{}) newClientConnection = func( @@ -193,29 +162,23 @@ var _ = Describe("Client", func() { ) quicConn { Expect(enable0RTT).To(BeTrue()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { <-done }) + conn.EXPECT().run().Do(func() { close(done) }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().earlyConnReady().Return(readyChan) return conn } - go func() { - defer GinkgoRecover() - defer close(done) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - }() - Consistently(done).ShouldNot(BeClosed()) - close(readyChan) + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false) + Expect(err).ToNot(HaveOccurred()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(Succeed()) Eventually(done).Should(BeClosed()) }) It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientConnection = func( @@ -236,108 +199,16 @@ var _ = Describe("Client", func() { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run().Return(testErr) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().earlyConnReady().Return(make(chan struct{})) return conn } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) - Expect(err).To(MatchError(testErr)) - }) - - It("closes the connection when the context is canceled", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - connRunning := make(chan struct{}) - defer close(connRunning) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { - <-connRunning - }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - return conn - } - ctx, cancel := context.WithCancel(context.Background()) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial(ctx, packetConn, addr, tlsConf, config) - Expect(err).To(MatchError(context.Canceled)) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - conn.EXPECT().shutdown() - cancel() - Eventually(dialed).Should(BeClosed()) - }) - - It("closes the connection when it was created by DialAddr", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - - var sconn sendConn - run := make(chan struct{}) - connCreated := make(chan struct{}) - conn := NewMockQUICConn(mockCtrl) - newClientConnection = func( - connP sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - sconn = connP - close(connCreated) - return conn - } - conn.EXPECT().run().Do(func() { - <-run - }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - - Eventually(connCreated).Should(BeClosed()) - - // check that the connection is not closed - Expect(sconn.Write([]byte("foobar"))).To(Succeed()) - - manager.EXPECT().Destroy() - close(run) - time.Sleep(50 * time.Millisecond) - - Eventually(done).Should(BeClosed()) + var closed bool + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false) + Expect(err).ToNot(HaveOccurred()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(MatchError(testErr)) + Expect(closed).To(BeTrue()) }) Context("quic.Config", func() { @@ -365,12 +236,6 @@ var _ = Describe("Client", func() { Expect(c.EnableDatagrams).To(BeTrue()) }) - It("errors when the Config contains an invalid version", func() { - version := protocol.VersionNumber(0x1234) - _, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("invalid QUIC version: 0x1234")) - }) - It("disables bidirectional streams", func() { config := &Config{ MaxIncomingStreams: -1, @@ -405,15 +270,12 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} + config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber var conf *Config + done := make(chan struct{}) newClientConnection = func( connP sendConn, _ connRunner, @@ -437,8 +299,15 @@ var _ = Describe("Client", func() { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run() conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().destroy(gomock.Any()) + close(done) return conn } + packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { + <-done + return 0, nil, errors.New("closed") + }) + packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) @@ -448,17 +317,12 @@ var _ = Describe("Client", func() { }) It("creates a new connections after version negotiation", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()).Times(2) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - var counter int newClientConnection = func( _ sendConn, - _ connRunner, - _ protocol.ConnectionID, + runner connRunner, _ protocol.ConnectionID, + connID protocol.ConnectionID, configP *Config, _ *tls.Config, pn protocol.PacketNumber, @@ -474,20 +338,24 @@ var _ = Describe("Client", func() { if counter == 0 { Expect(pn).To(BeZero()) Expect(hasNegotiatedVersion).To(BeFalse()) - conn.EXPECT().run().Return(&errCloseForRecreating{ - nextPacketNumber: 109, - nextVersion: 789, + conn.EXPECT().run().DoAndReturn(func() error { + runner.Remove(connID) + return &errCloseForRecreating{ + nextPacketNumber: 109, + nextVersion: 789, + } }) } else { Expect(pn).To(Equal(protocol.PacketNumber(109))) Expect(hasNegotiatedVersion).To(BeTrue()) conn.EXPECT().run() + conn.EXPECT().destroy(gomock.Any()) } counter++ return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -495,15 +363,3 @@ var _ = Describe("Client", func() { }) }) }) - -type mockConnIDGenerator struct { - ConnID protocol.ConnectionID -} - -func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) { - return m.ConnID, nil -} - -func (m *mockConnIDGenerator) ConnectionIDLen() int { - return m.ConnID.Len() -} diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 63c32677f..5f95c0b73 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() { getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}), ) Expect(err).ToNot(HaveOccurred()) + defer server.Close() var wg sync.WaitGroup wg.Add(2 * 4 * maxIncomingStreams) diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index 31905e302..d0bcf7f01 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer server.Close() var drop atomic.Bool dropped := make(chan []byte, 100) @@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") sconn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index c835a2ef2..7cb1904d3 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -35,7 +35,11 @@ var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } runServer := func(conf *quic.Config) *quic.Listener { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + if conf.ConnectionIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + } else { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + } ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf) Expect(err).ToNot(HaveOccurred()) go func() { @@ -59,7 +63,11 @@ var _ = Describe("Connection ID lengths tests", func() { } runClient := func(addr net.Addr, conf *quic.Config) { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + if conf.ConnectionIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + } else { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + } cl, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 552488f46..2ccbbe169 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() { const num = 100 var ( - proxy *quicproxy.QuicProxy serverConn, clientConn *net.UDPConn dropped, total int32 ) - startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) { + startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) serverConn, err = net.ListenUDP("udp", addr) @@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() { ) Expect(err).ToNot(HaveOccurred()) + accepted := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(accepted) conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) @@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() { }() serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), // drop 10% of Short Header packets sent from the server DropPacket: func(dir quicproxy.Direction, packet []byte) bool { @@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() { }, }) Expect(err).ToNot(HaveOccurred()) + return proxy.LocalPort(), func() { + Eventually(accepted).Should(BeClosed()) + proxy.Close() + ln.Close() + } } BeforeEach(func() { @@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) }) - AfterEach(func() { - Expect(proxy.Close()).To(Succeed()) - }) - It("sends datagrams", func() { - startServerAndProxy(true, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(true, true) + defer close() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() { for { // Close the connection if no message is received for 100 ms. timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { + fmt.Println("closing conn") conn.CloseWithError(0, "") }) if _, err := conn.ReceiveMessage(); err != nil { @@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() { BeNumerically(">", expVal*9/10), BeNumerically("<", num), )) + Eventually(conn.Context().Done).Should(BeClosed()) }) It("server can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + close() conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) }) It("client can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) + + close() conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) }) }) diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index 0ce099263..136c3d0b2 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -24,6 +24,7 @@ var _ = Describe("early data", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/integrationtests/self/go119_test.go b/integrationtests/self/go119_test.go index cd9824dd5..c676693da 100644 --- a/integrationtests/self/go119_test.go +++ b/integrationtests/self/go119_test.go @@ -8,10 +8,9 @@ import ( "time" ) -var ( - go120 = false - errNotSupported = errors.New("not supported") -) +const go120 = false + +var errNotSupported = errors.New("not supported") func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { return errNotSupported diff --git a/integrationtests/self/go120_test.go b/integrationtests/self/go120_test.go index 4ddf3c7c5..88eb4a7ed 100644 --- a/integrationtests/self/go120_test.go +++ b/integrationtests/self/go120_test.go @@ -7,7 +7,7 @@ import ( "time" ) -var go120 = true +const go120 = true func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { rc := http.NewResponseController(w) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index e1dff3210..36ea7c787 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 2) }) @@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 1) }) @@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 2) }) @@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) @@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 3274b84cb..b3a13e9d4 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptStream(context.Background()) @@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() { var ( server *quic.Listener pconn net.PacketConn + dialer *quic.Transport ) dial := func() (quic.Connection, error) { remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) raddr, err := net.ResolveUDPAddr("udp", remoteAddr) Expect(err).ToNot(HaveOccurred()) - return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil) + return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil)) } BeforeEach(func() { @@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) + dialer = &quic.Transport{Conn: pconn} }) AfterEach(func() { Expect(server.Close()).To(Succeed()) Expect(pconn.Close()).To(Succeed()) + Expect(dialer.Close()).To(Succeed()) }) It("rejects new connection attempts if connections don't get accepted", func() { @@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() { It("uses tokens provided in NEW_TOKEN frames", func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) + defer server.Close() // dial the first connection and receive the token go func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 4e384d9fb..918bfdf25 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() { tlsConf.NextProtos = []string{"h3"} ln, err := quic.ListenAddr("localhost:0", tlsConf, nil) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() { Eventually(done).Should(BeClosed()) }) - It("supports read deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } + if go120 { + It("supports read deadlines", func() { + mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) - mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + body, err := io.ReadAll(r.Body) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(body).To(ContainSubstring("aa")) + + w.Write([]byte("ok")) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(r.Body) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(body).To(ContainSubstring("aa")) - - w.Write([]byte("ok")) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(Equal("ok")) }) - expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) + It("supports write deadlines", func() { + mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(Equal("ok")) - }) + _, err = io.Copy(w, neverEnding('a')) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + }) - It("supports write deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } + expectedEnd := time.Now().Add(deadlineDelay) - mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + resp, err := client.Get("https://localhost:" + port + "/write-deadline") Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - _, err = io.Copy(w, neverEnding('a')) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(ContainSubstring("aa")) }) - - expectedEnd := time.Now().Add(deadlineDelay) - - resp, err := client.Get("https://localhost:" + port + "/write-deadline") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(ContainSubstring("aa")) - }) + } }) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index b75d16560..dcac1b469 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() { }() } - dial := func(pconn net.PacketConn, addr net.Addr) { - conn, err := quic.Dial( + dial := func(tr *quic.Transport, addr net.Addr) { + conn, err := tr.Dial( context.Background(), - pconn, addr, getTLSClientConfig(), getQuicConfig(nil), @@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} done1 := make(chan struct{}) done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done2) }() timeout := 30 * time.Second @@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} done1 := make(chan struct{}) done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server1.Addr()) + dial(tr, server1.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn, server2.Addr()) + dial(tr, server2.Addr()) close(done2) }() timeout := 30 * time.Second @@ -135,9 +136,9 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} - server, err := quic.Listen( - conn, + server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -146,7 +147,7 @@ var _ = Describe("Multiplexing", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done) }() timeout := 30 * time.Second @@ -165,15 +166,16 @@ var _ = Describe("Multiplexing", func() { conn1, err := net.ListenUDP("udp", addr1) Expect(err).ToNot(HaveOccurred()) defer conn1.Close() + tr1 := &quic.Transport{Conn: conn1} addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) conn2, err := net.ListenUDP("udp", addr2) Expect(err).ToNot(HaveOccurred()) defer conn2.Close() + tr2 := &quic.Transport{Conn: conn2} - server1, err := quic.Listen( - conn1, + server1, err := tr1.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -181,8 +183,7 @@ var _ = Describe("Multiplexing", func() { runServer(server1) defer server1.Close() - server2, err := quic.Listen( - conn2, + server2, err := tr2.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -194,12 +195,12 @@ var _ = Describe("Multiplexing", func() { done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn2, server1.Addr()) + dial(tr2, server1.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn1, server2.Addr()) + dial(tr1, server2.Addr()) close(done2) }() timeout := 30 * time.Second diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 1338b30ca..86062bd56 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -31,8 +31,8 @@ var _ = Describe("Packetization", func() { }), ) Expect(err).ToNot(HaveOccurred()) - serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) defer server.Close() + serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: serverAddr, @@ -54,6 +54,7 @@ var _ = Describe("Packetization", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") go func() { defer GinkgoRecover() diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 9dfafa317..966a57158 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -199,8 +199,16 @@ func areHandshakesRunning() bool { return strings.Contains(b.String(), "RunHandshake") } +func areTransportsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*Transport).listen") +} + var _ = AfterEach(func() { Expect(areHandshakesRunning()).To(BeFalse()) + Eventually(areTransportsRunning).Should(BeFalse()) + if debugLog() { logFile, err := os.Create(logFileName) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 98b21b2a6..cc5afe903 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -2,7 +2,6 @@ package self_test import ( "context" - "errors" "fmt" "math/rand" "net" @@ -27,7 +26,13 @@ var _ = Describe("Stateless Resets", func() { rand.Read(statelessResetKey[:]) serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey}) - ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + c, err := net.ListenUDP("udp", nil) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: c, + } + defer tr.Close() + ln, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port @@ -42,7 +47,8 @@ var _ = Describe("Stateless Resets", func() { _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) <-closeServer - ln.Close() + Expect(ln.Close()).To(Succeed()) + Expect(tr.Close()).To(Succeed()) }() var drop atomic.Bool @@ -77,11 +83,14 @@ var _ = Describe("Stateless Resets", func() { close(closeServer) time.Sleep(100 * time.Millisecond) - ln2, err := quic.ListenAddr( - fmt.Sprintf("localhost:%d", serverPort), - getTLSConfig(), - serverConfig, - ) + // We need to create a new Transport here, since the old one is still sending out + // CONNECTION_CLOSE packets for (recently) closed connections). + tr2 := &quic.Transport{ + Conn: c, + StatelessResetKey: &statelessResetKey, + } + defer tr2.Close() + ln2, err := tr2.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) drop.Store(false) @@ -100,8 +109,7 @@ var _ = Describe("Stateless Resets", func() { _, serr = str.Read([]byte{0}) } Expect(serr).To(HaveOccurred()) - statelessResetErr := &quic.StatelessResetError{} - Expect(errors.As(serr, &statelessResetErr)).To(BeTrue()) + Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{})) Expect(ln2.Close()).To(Succeed()) Eventually(acceptStopped).Should(BeClosed()) }) diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 0af14b8fe..332cd505e 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() { ) Expect(err).ToNot(HaveOccurred()) runSendingPeer(client) + client.CloseWithError(0, "") + <-conn.Context().Done() }) It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { @@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() { runReceivingPeer(client) <-done1 <-done2 + client.CloseWithError(0, "") }) }) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 5996a5340..abc05dd79 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() serverErrChan := make(chan error, 1) go func() { diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 9253b7012..a2fe4e501 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() { }) It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(conn) + <-conn.Context().Done() }() client, err := quic.DialAddr( @@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() { ) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(client) + client.CloseWithError(0, "") }) It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() { diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go deleted file mode 100644 index 0383b1a4d..000000000 --- a/mock_multiplexer_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: Multiplexer) - -// Package quic is a generated GoMock package. -package quic - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - logging "github.com/quic-go/quic-go/logging" -) - -// MockMultiplexer is a mock of Multiplexer interface. -type MockMultiplexer struct { - ctrl *gomock.Controller - recorder *MockMultiplexerMockRecorder -} - -// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer. -type MockMultiplexerMockRecorder struct { - mock *MockMultiplexer -} - -// NewMockMultiplexer creates a new mock instance. -func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer { - mock := &MockMultiplexer{ctrl: ctrl} - mock.recorder = &MockMultiplexerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { - return m.recorder -} - -// AddConn mocks base method. -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(packetHandlerManager) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AddConn indicates an expected call of AddConn. -func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3) -} - -// RemoveConn mocks base method. -func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveConn", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveConn indicates an expected call of RemoveConn. -func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0) -} diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index bcd160382..25ae5420c 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) } +// Close mocks base method. +func (m *MockPacketHandlerManager) Close(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close", arg0) +} + +// Close indicates an expected call of Close. +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) +} + // CloseServer mocks base method. func (m *MockPacketHandlerManager) CloseServer() { m.ctrl.T.Helper() @@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } -// Destroy mocks base method. -func (m *MockPacketHandlerManager) Destroy() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Destroy") - ret0, _ := ret[0].(error) - return ret0 -} - -// Destroy indicates an expected call of Destroy. -func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy)) -} - // Get mocks base method. func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { m.ctrl.T.Helper() @@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) } +// GetByResetToken mocks base method. +func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByResetToken", arg0) + ret0, _ := ret[0].(packetHandler) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetByResetToken indicates an expected call of GetByResetToken. +func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) +} + // GetStatelessResetToken mocks base method. func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { m.ctrl.T.Helper() @@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) } - -// SetServer mocks base method. -func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetServer", arg0) -} - -// SetServer indicates an expected call of SetServer. -func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0) -} diff --git a/mockgen.go b/mockgen.go index 443e9c10c..eb700864a 100644 --- a/mockgen.go +++ b/mockgen.go @@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer" -type Multiplexer = multiplexer - // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // diff --git a/multiplexer.go b/multiplexer.go index 37d4e75cf..85f7f4034 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -6,7 +6,6 @@ import ( "sync" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" ) var ( @@ -14,30 +13,19 @@ var ( connMuxer multiplexer ) -type indexableConn interface { - LocalAddr() net.Addr -} +type indexableConn interface{ LocalAddr() net.Addr } type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) + AddConn(conn indexableConn) RemoveConn(indexableConn) error } -type connManager struct { - connIDLen int - statelessResetKey *StatelessResetKey - tracer logging.Tracer - manager packetHandlerManager -} - // The connMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the connection handler. type connMultiplexer struct { mutex sync.Mutex - conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests - + conns map[string] /* LocalAddr().String() */ indexableConn logger utils.Logger } @@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{} func getMultiplexer() multiplexer { connMuxerOnce.Do(func() { connMuxer = &connMultiplexer{ - conns: make(map[string]connManager), - logger: utils.DefaultLogger.WithPrefix("muxer"), - newPacketHandlerManager: newPacketHandlerMap, + conns: make(map[string]indexableConn), + logger: utils.DefaultLogger.WithPrefix("muxer"), } }) return connMuxer } -func (m *connMultiplexer) AddConn( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, -) (packetHandlerManager, error) { +func (m *connMultiplexer) index(addr net.Addr) string { + return addr.Network() + " " + addr.String() +} + +func (m *connMultiplexer) AddConn(c indexableConn) { m.mutex.Lock() defer m.mutex.Unlock() - addr := c.LocalAddr() - connIndex := addr.Network() + " " + addr.String() + connIndex := m.index(c.LocalAddr()) p, ok := m.conns[connIndex] - if !ok { - manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) - if err != nil { - return nil, err - } - p = connManager{ - connIDLen: connIDLen, - statelessResetKey: statelessResetKey, - manager: manager, - tracer: tracer, - } - m.conns[connIndex] = p - } else { - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && p.statelessResetKey != statelessResetKey { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") - } - if tracer != p.tracer { - return nil, fmt.Errorf("cannot use different tracers on the same packet conn") - } + if ok { + // Panics if we're already listening on this connection. + // This is a safeguard because we're introducing a breaking API change, see + // https://github.com/quic-go/quic-go/issues/3727 for details. + // We'll remove this at a later time, when most users of the library have made the switch. + panic("connection already exists") // TODO: write a nice message } - return p.manager, nil + m.conns[connIndex] = p } func (m *connMultiplexer) RemoveConn(c indexableConn) error { m.mutex.Lock() defer m.mutex.Unlock() - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + connIndex := m.index(c.LocalAddr()) if _, ok := m.conns[connIndex]; !ok { return fmt.Errorf("cannote remove connection, connection is unknown") } diff --git a/multiplexer_test.go b/multiplexer_test.go index 3730cc334..48590b28e 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -3,71 +3,24 @@ package quic import ( "net" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" - - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -type testConn struct { - counter int - net.PacketConn -} - var _ = Describe("Multiplexer", func() { - It("adds a new packet conn ", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) - _, err := getMultiplexer().AddConn(conn, 8, nil, nil) - Expect(err).ToNot(HaveOccurred()) + It("adds new packet conns", func() { + conn1 := NewMockPacketConn(mockCtrl) + conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) + getMultiplexer().AddConn(conn1) + conn2 := NewMockPacketConn(mockCtrl) + conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235}) + getMultiplexer().AddConn(conn2) }) - It("recognizes when the same connection is added twice", func() { - srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'} - pconn := NewMockPacketConn(mockCtrl) - pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) - pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn := testConn{PacketConn: pconn} - tracer := mocklogging.NewMockTracer(mockCtrl) - _, err := getMultiplexer().AddConn(conn, 8, srk, tracer) - Expect(err).ToNot(HaveOccurred()) - conn.counter++ - _, err = getMultiplexer().AddConn(conn, 8, srk, tracer) - Expect(err).ToNot(HaveOccurred()) - Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) - }) - - It("errors when adding an existing conn with a different connection ID length", func() { + It("panics when the same connection is added twice", func() { conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 5, nil, nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6, nil, nil) - Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) - }) - - It("errors when adding an existing conn with a different stateless rest key", func() { - srk1 := &StatelessResetKey{'f', 'o', 'o'} - srk2 := &StatelessResetKey{'b', 'a', 'r'} - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, srk1, nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, srk2, nil) - Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) - }) - - It("errors when adding an existing conn with different tracers", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) + getMultiplexer().AddConn(conn) + Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 99e2bfb18..2a08359a9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -5,28 +5,22 @@ import ( "crypto/rand" "crypto/sha256" "errors" - "fmt" "hash" "io" - "log" "net" - "os" - "strconv" - "strings" "sync" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/logging" ) -// rawConn is a connection that allow reading of a receivedPacket. +// rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (*receivedPacket, error) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr + SetReadDeadline(time.Time) error io.Closer } @@ -36,113 +30,49 @@ type closePacket struct { info *packetInfo } -// The packetHandlerMap stores packetHandlers, identified by connection ID. -// It is used: -// * by the server to store connections -// * when multiplexing outgoing connections to store clients +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + setCloseError(error) +} + +var errListenerAlreadySet = errors.New("listener already set") + type packetHandlerMap struct { - mutex sync.Mutex - - conn rawConn - connIDLen int - - closeQueue chan closePacket - + mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - listening chan struct{} // is closed when listen returns closed bool + closeChan chan struct{} + + enqueueClosePacket func(closePacket) deleteRetiredConnsAfter time.Duration - statelessResetEnabled bool - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash - tracer logging.Tracer logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { - conn, ok := c.(interface{ SetReadBuffer(int) error }) - if !ok { - return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") - } - size, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if size >= protocol.DesiredReceiveBufferSize { - logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) - return nil - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return fmt.Errorf("failed to increase receive buffer size: %w", err) - } - newSize, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if newSize == size { - return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - if newSize < protocol.DesiredReceiveBufferSize { - return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) - return nil -} - -// only print warnings about the UDP receive buffer size once -var receiveBufferWarningOnce sync.Once - -func newPacketHandlerMap( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, - logger utils.Logger, -) (packetHandlerManager, error) { - if err := setReceiveBuffer(c, logger); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - receiveBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) - } - } - conn, err := wrapConn(c) - if err != nil { - return nil, err - } - m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), +func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { + h := &packetHandlerMap{ + closeChan: make(chan struct{}), handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - closeQueue: make(chan closePacket, 4), - statelessResetEnabled: statelessResetKey != nil, - tracer: tracer, + enqueueClosePacket: enqueueClosePacket, logger: logger, } - if m.statelessResetEnabled { - m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:]) + if key != nil { + h.statelessResetHasher = hmac.New(sha256.New, key[:]) } - go m.listen() - go m.runCloseQueue() - - if logger.Debug() { - go m.logUsage() + if h.logger.Debug() { + go h.logUsage() } - return m, nil + return h } func (h *packetHandlerMap) logUsage() { @@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() { var printedZero bool for { select { - case <-h.listening: + case <-h.closeChan: return case <-ticker.C: } @@ -233,12 +163,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p if connClosePacket != nil { handler = newClosedLocalConn( func(addr net.Addr, info *packetInfo) { - select { - case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: - default: - // Oops, we're backlogged. - // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. - } + h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, pers, h.logger, @@ -265,17 +190,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p }) } -func (h *packetHandlerMap) runCloseQueue() { - for { - select { - case <-h.listening: - return - case p := <-h.closeQueue: - h.conn.WritePacket(p.payload, p.addr, p.info.OOB()) - } - } -} - func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler @@ -288,19 +202,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) h.mutex.Unlock() } -func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { +func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) { h.mutex.Lock() - h.server = s - h.mutex.Unlock() + defer h.mutex.Unlock() + + handler, ok := h.resetTokens[token] + return handler, ok } func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() - if h.server == nil { - h.mutex.Unlock() - return - } - h.server = nil var wg sync.WaitGroup for _, handler := range h.handlers { if handler.getPerspective() == protocol.PerspectiveServer { @@ -316,23 +227,16 @@ func (h *packetHandlerMap) CloseServer() { wg.Wait() } -// Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active connections. -func (h *packetHandlerMap) Destroy() error { - if err := h.conn.Close(); err != nil { - return err - } - <-h.listening // wait until listening returns - return nil -} - -func (h *packetHandlerMap) close(e error) error { +func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() + if h.closed { h.mutex.Unlock() - return nil + return } + close(h.closeChan) + var wg sync.WaitGroup for _, handler := range h.handlers { wg.Add(1) @@ -341,89 +245,14 @@ func (h *packetHandlerMap) close(e error) error { wg.Done() }(handler) } - - if h.server != nil { - h.server.setCloseError(e) - } h.closed = true h.mutex.Unlock() wg.Wait() - return getMultiplexer().RemoveConn(h.conn) -} - -func (h *packetHandlerMap) listen() { - defer close(h.listening) - for { - p, err := h.conn.ReadPacket() - //nolint:staticcheck // SA1019 ignore this! - // TODO: This code is used to ignore wsa errors on Windows. - // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. - // See https://github.com/quic-go/quic-go/issues/1737 for details. - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - h.logger.Debugf("Temporary error reading from conn: %w", err) - continue - } - if err != nil { - h.close(err) - return - } - h.handlePacket(p) - } -} - -func (h *packetHandlerMap) handlePacket(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, h.connIDLen) - if err != nil { - h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - p.buffer.MaybeRelease() - return - } - - h.mutex.Lock() - defer h.mutex.Unlock() - - if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } - if handler, ok := h.handlers[connID]; ok { - handler.handlePacket(p) - return - } - if !wire.IsLongHeaderPacket(p.data[0]) { - go h.maybeSendStatelessReset(p, connID) - return - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) - return - } - h.server.handlePacket(p) -} - -func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { - // stateless resets are always short header packets - if wire.IsLongHeaderPacket(data[0]) { - return false - } - if len(data) < 17 /* type byte + 16 bytes for the reset token */ { - return false - } - - token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go sess.destroy(&StatelessResetError{Token: token}) - return true - } - return false } func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { var token protocol.StatelessResetToken - if !h.statelessResetEnabled { + if h.statelessResetHasher == nil { // Return a random stateless reset token. // This token will be sent in the server's transport parameters. // By using a random token, an off-path attacker won't be able to disrupt the connection. @@ -437,24 +266,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) h.statelessResetMutex.Unlock() return token } - -func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { - defer p.buffer.Release() - if !h.statelessResetEnabled { - return - } - // Don't send a stateless reset in response to very small packets. - // This includes packets that could be stateless resets. - if len(p.data) <= protocol.MinStatelessResetSize { - return - } - token := h.GetStatelessResetToken(connID) - h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) - data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) - rand.Read(data) - data[0] = (data[0] & 0x7f) | 0x40 - data = append(data, token[:]...) - if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - h.logger.Debugf("Error sending Stateless Reset: %s", err) - } -} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 0397e3f0a..e87a75f88 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -6,405 +6,188 @@ import ( "net" "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Packet Handler Map", func() { - type packetToRead struct { - addr net.Addr - data []byte - err error - } - - var ( - handler *packetHandlerMap - conn *MockPacketConn - tracer *mocklogging.MockTracer - packetChan chan packetToRead - - connIDLen int - statelessResetKey *StatelessResetKey - ) - - getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { - b, err := (&wire.ExtendedHeader{ - Header: wire.Header{ - Type: t, - DestConnectionID: connID, - Length: length, - Version: protocol.Version1, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }).Append(nil, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - return b - } - - getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) - } - - BeforeEach(func() { - statelessResetKey = nil - connIDLen = 0 - tracer = mocklogging.NewMockTracer(mockCtrl) - packetChan = make(chan packetToRead, 10) + It("adds and gets", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).To(Equal(handler)) }) - JustBeforeEach(func() { - conn = NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { - p, ok := <-packetChan - if !ok { - return 0, nil, errors.New("closed") + It("refused to add duplicates", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + Expect(m.Add(connID, handler)).To(BeFalse()) + }) + + It("removes", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.Remove(connID) + _, ok := m.Get(connID) + Expect(ok).To(BeFalse()) + Expect(m.Add(connID, handler)).To(BeTrue()) + }) + + It("retires", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.Retire(connID) + _, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("adds newly to-be-constructed handlers", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + var called bool + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.AddWithConnID(connID1, connID2, func() packetHandler { + called = true + return NewMockPacketHandler(mockCtrl) + })).To(BeTrue()) + Expect(called).To(BeTrue()) + Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler { + Fail("didn't expect the constructor to be executed") + return nil + })).To(BeFalse()) + }) + + It("adds, gets and removes reset tokens", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + handler := NewMockPacketHandler(mockCtrl) + m.AddResetToken(token, handler) + h, ok := m.GetByResetToken(token) + Expect(ok).To(BeTrue()) + Expect(h).To(Equal(h)) + m.RemoveResetToken(token) + _, ok = m.GetByResetToken(token) + Expect(ok).To(BeFalse()) + }) + + It("generates stateless reset token, if no key is set", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + b := make([]byte, 8) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + token := m.GetStatelessResetToken(connID) + for i := 0; i < 1000; i++ { + to := m.GetStatelessResetToken(connID) + Expect(to).ToNot(Equal(token)) + token = to + } + }) + + It("generates stateless reset token, if a key is set", func() { + var key StatelessResetKey + rand.Read(key[:]) + m := newPacketHandlerMap(&key, nil, utils.DefaultLogger) + b := make([]byte, 8) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + token := m.GetStatelessResetToken(connID) + Expect(token).ToNot(BeZero()) + Expect(m.GetStatelessResetToken(connID)).To(Equal(token)) + // generate a new connection ID + rand.Read(b) + connID2 := protocol.ParseConnectionID(b) + Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token)) + }) + + It("replaces locally closed connections", func() { + var closePackets []closePacket + m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + + handler := NewMockPacketHandler(mockCtrl) + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar")) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).ToNot(Equal(handler)) + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + h.handlePacket(&receivedPacket{remoteAddr: addr}) + Expect(closePackets).To(HaveLen(1)) + Expect(closePackets[0].addr).To(Equal(addr)) + Expect(closePackets[0].payload).To(Equal([]byte("foobar"))) + + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("replaces remote closed connections", func() { + var closePackets []closePacket + m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + + handler := NewMockPacketHandler(mockCtrl) + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).ToNot(Equal(handler)) + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + h.handlePacket(&receivedPacket{remoteAddr: addr}) + Expect(closePackets).To(BeEmpty()) + + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("closes the server", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + for i := 0; i < 10; i++ { + conn := NewMockPacketHandler(mockCtrl) + if i%2 == 0 { + conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) + } else { + conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) + conn.EXPECT().shutdown() } - return copy(b, p.data), p.addr, p.err - }).AnyTimes() - phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) - Expect(err).ToNot(HaveOccurred()) - handler = phm.(*packetHandlerMap) + b := make([]byte, 12) + rand.Read(b) + m.Add(protocol.ParseConnectionID(b), conn) + } + m.CloseServer() }) It("closes", func() { - getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer := NewMockMultiplexer(mockCtrl) - origMultiplexer := connMuxer - connMuxer = mockMultiplexer - - defer func() { - connMuxer = origMultiplexer - }() - - testErr := errors.New("test error ") - conn1 := NewMockPacketHandler(mockCtrl) - conn1.EXPECT().destroy(testErr) - conn2 := NewMockPacketHandler(mockCtrl) - conn2.EXPECT().destroy(testErr) - handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1) - handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2) - mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) - handler.close(testErr) - close(packetChan) - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("other operations", func() { - AfterEach(func() { - // delete connections and the server before closing - // They might be mock implementations, and we'd have to register the expected calls before otherwise. - handler.mutex.Lock() - for connID := range handler.handlers { - delete(handler.handlers, connID) - } - handler.server = nil - handler.mutex.Unlock() - conn.EXPECT().Close().MaxTimes(1) - close(packetChan) - handler.Destroy() - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("handling packets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) - packetHandler1 := NewMockPacketHandler(mockCtrl) - packetHandler2 := NewMockPacketHandler(mockCtrl) - handledPacket1 := make(chan struct{}) - handledPacket2 := make(chan struct{}) - packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID1)) - close(handledPacket1) - }) - packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID2)) - close(handledPacket2) - }) - handler.Add(connID1, packetHandler1) - handler.Add(connID2, packetHandler2) - packetChan <- packetToRead{data: getPacket(connID1)} - packetChan <- packetToRead{data: getPacket(connID2)} - - Eventually(handledPacket1).Should(BeClosed()) - Eventually(handledPacket2).Should(BeClosed()) - }) - - It("drops unparseable packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: []byte{0, 1, 2, 3}, - }) - }) - - It("deletes removed connections immediately", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - handler.Add(connID, NewMockPacketHandler(mockCtrl)) - handler.Remove(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("deletes retired connection entries after a wait time", func() { - handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - conn := NewMockPacketHandler(mockCtrl) - handler.Add(connID, conn) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("passes packets arriving late for closed connections to that connection", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - packetHandler := NewMockPacketHandler(mockCtrl) - handled := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - close(handled) - }) - handler.Add(connID, packetHandler) - handler.Retire(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - Eventually(handled).Should(BeClosed()) - }) - - It("drops packets for unknown receivers", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - }) - - It("closes the packet handlers when reading from the conn fails", func() { - done := make(chan struct{}) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { - Expect(e).To(HaveOccurred()) - close(done) - }) - handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) - packetChan <- packetToRead{err: errors.New("read failed")} - Eventually(done).Should(BeClosed()) - }) - - It("continues listening for temporary errors", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) - err := deadlineError{} - Expect(err.Temporary()).To(BeTrue()) - packetChan <- packetToRead{err: err} - // don't EXPECT any calls to packetHandler.destroy - time.Sleep(50 * time.Millisecond) - }) - - It("says if a connection ID is already taken", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) - }) - - It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) - Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) - }) - }) - - Context("running a server", func() { - It("adds a server", func() { - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - }) - handler.SetServer(server) - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("closes all server connections", func() { - handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) - clientConn := NewMockPacketHandler(mockCtrl) - clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - serverConn := NewMockPacketHandler(mockCtrl) - serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - serverConn.EXPECT().shutdown() - - handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn) - handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn) - handler.CloseServer() - }) - - It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - handler.CloseServer() - handler.handlePacket(&receivedPacket{data: p}) - }) - }) - - Context("stateless resets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - Context("handling", func() { - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - defer close(destroyed) - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - close(destroyed) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("removes reset tokens", func() { - connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RemoveResetToken(token) - // don't EXPECT any call to packetHandler.destroy() - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("ignores packets too small to contain a stateless reset", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - done := make(chan struct{}) - // don't EXPECT any calls here, but register the closing of the done channel - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { - close(done) - }).AnyTimes() - packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} - Consistently(done).ShouldNot(BeClosed()) - }) - }) - - Context("generating", func() { - BeforeEach(func() { - var key StatelessResetKey - rand.Read(key[:]) - statelessResetKey = &key - }) - - It("generates stateless reset tokens", func() { - connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) - connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) - Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { - defer close(done) - Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet - Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) - }) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - - Context("if no key is configured", func() { - It("doesn't send stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - }) + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + testErr := errors.New("shutdown") + for i := 0; i < 10; i++ { + conn := NewMockPacketHandler(mockCtrl) + conn.EXPECT().destroy(testErr) + b := make([]byte, 12) + rand.Read(b) + m.Add(protocol.ParseConnectionID(b), conn) + } + m.Close(testErr) + // check that Close can be called multiple times + m.Close(errors.New("close")) }) }) diff --git a/quic_suite_test.go b/quic_suite_test.go index 0eb6f03c3..d979d81bc 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -1,8 +1,11 @@ package quic import ( + "bytes" "io" "log" + "runtime/pprof" + "strings" "sync" "testing" @@ -29,6 +32,20 @@ var _ = BeforeSuite(func() { log.SetOutput(io.Discard) }) +func areServersRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*baseServer).run") +} + +func areTransportsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*Transport).listen") +} + var _ = AfterEach(func() { mockCtrl.Finish() + Eventually(areServersRunning).Should(BeFalse()) + Eventually(areTransportsRunning()).Should(BeFalse()) }) diff --git a/server.go b/server.go index d5bb19e6d..f8c9b3cd1 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,7 @@ import ( ) // ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. -var ErrServerClosed = errors.New("quic: Server closed") +var ErrServerClosed = errors.New("quic: server closed") // packetHandler handles packets type packetHandler interface { @@ -30,18 +30,13 @@ type packetHandler interface { getPerspective() protocol.Perspective } -type unknownPacketHandler interface { - handlePacket(*receivedPacket) - setCloseError(error) -} - type packetHandlerManager interface { Get(protocol.ConnectionID) (packetHandler, bool) + GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool - Destroy() error - connRunner - SetServer(unknownPacketHandler) + Close(error) CloseServer() + connRunner } type quicConn interface { @@ -70,13 +65,11 @@ type baseServer struct { config *Config conn rawConn - // If the server is started with ListenAddr, we create a packet conn. - // If it is started with Listen, we take a packet conn as a parameter. - createdPacketConn bool tokenGenerator *handshake.TokenGenerator connHandler packetHandlerManager + onClose func() receivedPackets chan *receivedPacket @@ -114,8 +107,6 @@ type baseServer struct { logger utils.Logger } -var _ unknownPacketHandler = &baseServer{} - // A Listener listens for incoming QUIC connections. // It returns connections once the handshake has completed. type Listener struct { @@ -166,37 +157,36 @@ func (l *EarlyListener) Addr() net.Addr { // The tls.Config must not be nil and must contain a certificate configuration. // The quic.Config may be nil, in that case the default values will be used. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { - s, err := listenAddr(addr, tlsConf, config, false) + conn, err := listenUDP(addr) if err != nil { return nil, err } - return &Listener{baseServer: s}, nil + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).Listen(tlsConf, config) } // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { - s, err := listenAddr(addr, tlsConf, config, true) + conn, err := listenUDP(addr) if err != nil { return nil, err } - return &EarlyListener{baseServer: s}, nil + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).ListenEarly(tlsConf, config) } -func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { +func listenUDP(addr string) (*net.UDPConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - serv, err := listen(conn, tlsConf, config, acceptEarly) - if err != nil { - return nil, err - } - serv.createdPacketConn = true - return serv, nil + return net.ListenUDP("udp", udpAddr) } // Listen listens for QUIC connections on a given net.PacketConn. If the @@ -210,45 +200,23 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo // Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { - s, err := listen(conn, tlsConf, config, false) - if err != nil { - return nil, err - } - return &Listener{baseServer: s}, nil + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.Listen(tlsConf, config) } // ListenEarly works like Listen, but it returns connections before the handshake completes. func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { - s, err := listen(conn, tlsConf, config, true) - if err != nil { - return nil, err - } - return &EarlyListener{baseServer: s}, nil + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.ListenEarly(tlsConf, config) } -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateServerConfig(config) - - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } +func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err } - c, err := wrapConn(conn) - if err != nil { - return nil, err - } s := &baseServer{ - conn: c, + conn: conn, tlsConf: tlsConf, config: config, tokenGenerator: tokenGenerator, @@ -260,12 +228,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl newConn: newConnection, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} } go s.run() - connHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -317,18 +285,12 @@ func (s *baseServer) Close() error { if s.serverError == nil { s.serverError = ErrServerClosed } - // If the server was started with ListenAddr, we created the packet conn. - // We need to close it in order to make the go routine reading from that conn return. - createdPacketConn := s.createdPacketConn s.closed = true close(s.errorChan) s.mutex.Unlock() <-s.running - s.connHandler.CloseServer() - if createdPacketConn { - return s.connHandler.Destroy() - } + s.onClose() return nil } diff --git a/server_test.go b/server_test.go index 4108c6989..7f6e49a79 100644 --- a/server_test.go +++ b/server_test.go @@ -1,15 +1,12 @@ package quic import ( - "bytes" "context" "crypto/rand" "crypto/tls" "errors" "net" "reflect" - "runtime/pprof" - "strings" "sync" "sync/atomic" "time" @@ -24,17 +21,10 @@ import ( "github.com/quic-go/quic-go/logging" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -func areServersRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*baseServer).run") -} - var _ = Describe("Server", func() { var ( conn *MockPacketConn @@ -96,15 +86,19 @@ var _ = Describe("Server", func() { BeforeEach(func() { conn = NewMockPacketConn(mockCtrl) conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) + wait := make(chan struct{}) + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { + <-wait + return 0, nil, errors.New("done") + }).MaxTimes(1) + conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) { + close(wait) + conn.EXPECT().SetReadDeadline(time.Time{}) + }).MaxTimes(1) tlsConf = testdata.GetTLSConfig() tlsConf.NextProtos = []string{"proto1"} }) - AfterEach(func() { - Eventually(areServersRunning).Should(BeFalse()) - }) - It("errors when no tls.Config is given", func() { _, err := ListenAddr("localhost:0", nil, nil) Expect(err).To(HaveOccurred()) @@ -178,6 +172,7 @@ var _ = Describe("Server", func() { Context("server accepting connections that completed the handshake", func() { var ( + ln *Listener serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -185,7 +180,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) + var err error + ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer}) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) @@ -193,8 +189,7 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() + ln.Close() }) Context("handling packets", func() { @@ -753,8 +748,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -968,6 +962,7 @@ var _ = Describe("Server", func() { serv.setCloseError(testErr) Eventually(done).Should(BeClosed()) + serv.onClose() // shutdown }) It("returns immediately, if an error occurred before", func() { @@ -977,6 +972,7 @@ var _ = Describe("Server", func() { _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) } + serv.onClose() // shutdown }) It("returns when the context is canceled", func() { @@ -1064,7 +1060,6 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) serv.Close() }) @@ -1234,8 +1229,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/transport.go b/transport.go new file mode 100644 index 000000000..2e860ee27 --- /dev/null +++ b/transport.go @@ -0,0 +1,410 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/quic-go/quic-go/internal/wire" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type Transport struct { + // A single net.PacketConn can only be handled by one Transport. + // Bad things will happen if passed to multiple Transports. + // + // If the connection satisfies the OOBCapablePacketConn interface + // (as a net.UDPConn does), ECN and packet info support will be enabled. + // In this case, optimized syscalls might be used, skipping the + // ReadFrom and WriteTo calls to read / write packets. + Conn net.PacketConn + + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If unset, a 4 byte connection ID will be used. + ConnectionIDLength int + + // Use for generating new connection IDs. + // This allows the application to control of the connection IDs used, + // which allows routing / load balancing based on connection IDs. + // All Connection IDs returned by the ConnectionIDGenerator MUST + // have the same length. + ConnectionIDGenerator ConnectionIDGenerator + + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + StatelessResetKey *StatelessResetKey + + // A Tracer traces events that don't belong to a single QUIC connection. + Tracer logging.Tracer + + handlerMap packetHandlerManager + + mutex sync.Mutex + initOnce sync.Once + initErr error + + // Set in init. + // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. + connIDLen int + + server unknownPacketHandler + + conn rawConn + + closeQueue chan closePacket + + listening chan struct{} // is closed when listen returns + closed bool + createdConn bool + isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + + logger utils.Logger +} + +// Listen starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(conf); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false) + if err != nil { + return nil, err + } + t.server = s + return &Listener{baseServer: s}, nil +} + +// ListenEarly starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(conf); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true) + if err != nil { + return nil, err + } + t.server = s + return &EarlyListener{baseServer: s}, nil +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateClientConfig(conf, t.createdConn) + if err := t.init(conf); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. +func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateClientConfig(conf, t.createdConn) + if err := t.init(conf); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) +} + +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + size, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return fmt.Errorf("failed to increase receive buffer size: %w", err) + } + newSize, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} + +// only print warnings about the UDP receive buffer size once +var receiveBufferWarningOnce sync.Once + +func (t *Transport) init(conf *Config) error { + t.initOnce.Do(func() { + getMultiplexer().AddConn(t.Conn) + + conn, err := wrapConn(t.Conn) + if err != nil { + t.initErr = err + return + } + + t.StatelessResetKey = conf.StatelessResetKey + t.Tracer = conf.Tracer + t.ConnectionIDLength = conf.ConnectionIDLength + t.ConnectionIDGenerator = conf.ConnectionIDGenerator + + t.logger = utils.DefaultLogger // TODO: make this configurable + t.conn = conn + t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + t.listening = make(chan struct{}) + + t.closeQueue = make(chan closePacket, 4) + + if t.ConnectionIDGenerator != nil { + t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() + } else { + t.connIDLen = t.ConnectionIDLength + } + + go t.listen(conn) + go t.runCloseQueue() + }) + return t.initErr +} + +func (t *Transport) enqueueClosePacket(p closePacket) { + select { + case t.closeQueue <- p: + default: + // Oops, we're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } +} + +func (t *Transport) runCloseQueue() { + for { + select { + case <-t.listening: + return + case p := <-t.closeQueue: + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + } + } +} + +// Close closes the underlying connection and waits until listen has returned. +// It is invalid to start new listeners or connections after that. +func (t *Transport) Close() error { + t.close(errors.New("closing")) + if t.createdConn { + if err := t.conn.Close(); err != nil { + return err + } + } else { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + } + <-t.listening // wait until listening returns + return nil +} + +func (t *Transport) closeServer() { + t.handlerMap.CloseServer() + t.mutex.Lock() + t.server = nil + if t.isSingleUse { + t.closed = true + } + t.mutex.Unlock() + if t.createdConn { + t.Conn.Close() + } + if t.isSingleUse { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + <-t.listening // wait until listening returns + } +} + +func (t *Transport) close(e error) { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.closed { + return + } + + t.handlerMap.Close(e) + if t.server != nil { + t.server.setCloseError(e) + } + t.closed = true +} + +func (t *Transport) listen(conn rawConn) { + defer close(t.listening) + defer getMultiplexer().RemoveConn(t.Conn) + + if err := setReceiveBuffer(t.Conn, t.logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } + } + + for { + p, err := conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/quic-go/quic-go/issues/1737 for details. + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + t.mutex.Lock() + closed := t.closed + t.mutex.Unlock() + if closed { + return + } + t.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + t.close(err) + return + } + t.handlePacket(p) + } +} + +func (t *Transport) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, t.connIDLen) + if err != nil { + t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + if handler, ok := t.handlerMap.Get(connID); ok { + handler.handlePacket(p) + return + } + if !wire.IsLongHeaderPacket(p.data[0]) { + go t.maybeSendStatelessReset(p, connID) + return + } + + t.mutex.Lock() + defer t.mutex.Unlock() + if t.server == nil { // no server set + t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + t.server.handlePacket(p) +} + +func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + if t.StatelessResetKey == nil { + return + } + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } + token := t.handlerMap.GetStatelessResetToken(connID) + t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + t.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} + +func (t *Transport) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if wire.IsLongHeaderPacket(data[0]) { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) + if conn, ok := t.handlerMap.GetByResetToken(token); ok { + t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) + go conn.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 000000000..122dbba36 --- /dev/null +++ b/transport_test.go @@ -0,0 +1,287 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "net" + "time" + + mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Transport", func() { + type packetToRead struct { + addr net.Addr + data []byte + err error + } + + getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { + b, err := (&wire.ExtendedHeader{ + Header: wire.Header{ + Type: t, + DestConnectionID: connID, + Length: length, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + return b + } + + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) + } + + newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + p, ok := <-packetChan + if !ok { + return 0, nil, errors.New("closed") + } + return copy(b, p.data), p.addr, p.err + }).AnyTimes() + // for shutdown + conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() + return conn + } + + It("handles packets for different packet handlers on the same packet conn", func() { + packetChan := make(chan packetToRead) + tr := &Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{}) + phm := NewMockPacketHandlerManager(mockCtrl) + tr.handlerMap = phm + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + + handled := make(chan struct{}, 2) + phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { + h := NewMockPacketHandler(mockCtrl) + h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + defer GinkgoRecover() + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) + handled <- struct{}{} + }) + return h, true + }) + phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { + h := NewMockPacketHandler(mockCtrl) + h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + defer GinkgoRecover() + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) + handled <- struct{}{} + }) + return h, true + }) + + packetChan <- packetToRead{data: getPacket(connID1)} + packetChan <- packetToRead{data: getPacket(connID2)} + + Eventually(handled).Should(Receive()) + Eventually(handled).Should(Receive()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("closes listeners", func() { + packetChan := make(chan packetToRead) + tr := &Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + ln, err := tr.Listen(&tls.Config{}, nil) + Expect(err).ToNot(HaveOccurred()) + phm := NewMockPacketHandlerManager(mockCtrl) + tr.handlerMap = phm + + phm.EXPECT().CloseServer() + Expect(ln.Close()).To(Succeed()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("drops unparseable packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + } + tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10}) + dropped := make(chan struct{}) + tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) + packetChan <- packetToRead{ + addr: addr, + data: []byte{0, 1, 2, 3}, + } + Eventually(dropped).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("closes when reading from the conn fails", func() { + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + done := make(chan struct{}) + phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) + packetChan <- packetToRead{err: errors.New("read failed")} + Eventually(done).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("continues listening after temporary errors", func() { + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + tempErr := deadlineError{} + Expect(tempErr.Temporary()).To(BeTrue()) + packetChan <- packetToRead{err: tempErr} + // don't expect any calls to phm.Close + time.Sleep(50 * time.Millisecond) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("handles short header packets resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{ConnectionIDLength: connID.Len()}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, token[:]...) + conn := NewMockPacketHandler(mockCtrl) + gomock.InOrder( + phm.EXPECT().GetByResetToken(token), + phm.EXPECT().Get(connID).Return(conn, true), + conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(Equal(b)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) + }), + ) + packetChan <- packetToRead{data: b} + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("handles stateless resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{ConnectionIDLength: connID.Len()}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, token[:]...) + conn := NewMockPacketHandler(mockCtrl) + gomock.InOrder( + phm.EXPECT().GetByResetToken(token).Return(conn, true), + conn.EXPECT().destroy(gomock.Any()).Do(func(err error) { + Expect(err).To(MatchError(&StatelessResetError{Token: token})) + }), + ) + packetChan <- packetToRead{data: b} + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("sends stateless resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + conn := newMockPacketConn(packetChan) + tr := Transport{ + Conn: conn, + } + tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...) + + var token protocol.StatelessResetToken + rand.Read(token[:]) + written := make(chan struct{}) + gomock.InOrder( + phm.EXPECT().GetByResetToken(gomock.Any()), + phm.EXPECT().Get(connID), + phm.EXPECT().GetStatelessResetToken(connID).Return(token), + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) { + defer close(written) + Expect(bytes.Contains(b, token[:])).To(BeTrue()) + }), + ) + packetChan <- packetToRead{data: b} + Eventually(written).Should(BeClosed()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) +}) From b79b532b04144e2d8bff48597f89bb7ceddb86d7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 6 Apr 2023 19:07:37 +0800 Subject: [PATCH 3/8] remove StatelessResetKey from the Config, it's now on the Transport --- client_test.go | 3 --- config.go | 1 - integrationtests/self/stateless_reset_test.go | 8 ++++---- interface.go | 3 --- server_test.go | 2 -- transport.go | 1 - transport_test.go | 5 +++-- 7 files changed, 7 insertions(+), 16 deletions(-) diff --git a/client_test.go b/client_test.go index ec8b83b8b..71e16f0b4 100644 --- a/client_test.go +++ b/client_test.go @@ -213,7 +213,6 @@ var _ = Describe("Client", func() { Context("quic.Config", func() { It("setups with the right values", func() { - srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'} tokenStore := NewLRUTokenStore(10, 4) config := &Config{ HandshakeIdleTimeout: 1337 * time.Minute, @@ -221,7 +220,6 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, ConnectionIDLength: 13, - StatelessResetKey: srk, TokenStore: tokenStore, EnableDatagrams: true, } @@ -231,7 +229,6 @@ var _ = Describe("Client", func() { Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) Expect(c.ConnectionIDLength).To(Equal(13)) - Expect(c.StatelessResetKey).To(Equal(srk)) Expect(c.TokenStore).To(Equal(tokenStore)) Expect(c.EnableDatagrams).To(BeTrue()) }) diff --git a/config.go b/config.go index b513a46b3..7a48b4073 100644 --- a/config.go +++ b/config.go @@ -137,7 +137,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { MaxIncomingUniStreams: maxIncomingUniStreams, ConnectionIDLength: conIDLen, ConnectionIDGenerator: connIDGenerator, - StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index cc5afe903..3c9a17034 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -24,15 +24,15 @@ var _ = Describe("Stateless Resets", func() { It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() { var statelessResetKey quic.StatelessResetKey rand.Read(statelessResetKey[:]) - serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey}) c, err := net.ListenUDP("udp", nil) Expect(err).ToNot(HaveOccurred()) tr := &quic.Transport{ - Conn: c, + Conn: c, + StatelessResetKey: &statelessResetKey, } defer tr.Close() - ln, err := tr.Listen(getTLSConfig(), serverConfig) + ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port @@ -90,7 +90,7 @@ var _ = Describe("Stateless Resets", func() { StatelessResetKey: &statelessResetKey, } defer tr2.Close() - ln2, err := tr2.Listen(getTLSConfig(), serverConfig) + ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) drop.Store(false) diff --git a/interface.go b/interface.go index 43101448d..33309facf 100644 --- a/interface.go +++ b/interface.go @@ -314,9 +314,6 @@ type Config struct { // If not set, it will default to 100. // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int64 - // The StatelessResetKey is used to generate stateless reset tokens. - // If no key is configured, sending of stateless resets is disabled. - StatelessResetKey *StatelessResetKey // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // every half of MaxIdleTimeout, whichever is smaller). diff --git a/server_test.go b/server_test.go index 7f6e49a79..658bb45bd 100644 --- a/server_test.go +++ b/server_test.go @@ -132,7 +132,6 @@ var _ = Describe("Server", func() { HandshakeIdleTimeout: 1337 * time.Hour, MaxIdleTimeout: 42 * time.Minute, KeepAlivePeriod: 5 * time.Second, - StatelessResetKey: &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}, RequireAddressValidation: requireAddrVal, } ln, err := Listen(conn, tlsConf, &config) @@ -144,7 +143,6 @@ var _ = Describe("Server", func() { Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal))) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) - Expect(server.config.StatelessResetKey).To(Equal(&StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'})) // stop the listener Expect(ln.Close()).To(Succeed()) }) diff --git a/transport.go b/transport.go index 2e860ee27..bf902fbc5 100644 --- a/transport.go +++ b/transport.go @@ -207,7 +207,6 @@ func (t *Transport) init(conf *Config) error { return } - t.StatelessResetKey = conf.StatelessResetKey t.Tracer = conf.Tracer t.ConnectionIDLength = conf.ConnectionIDLength t.ConnectionIDGenerator = conf.ConnectionIDGenerator diff --git a/transport_test.go b/transport_test.go index 122dbba36..c404a3b92 100644 --- a/transport_test.go +++ b/transport_test.go @@ -251,9 +251,10 @@ var _ = Describe("Transport", func() { packetChan := make(chan packetToRead) conn := newMockPacketConn(packetChan) tr := Transport{ - Conn: conn, + Conn: conn, + StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, } - tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}}) + tr.init(&Config{ConnectionIDLength: connID.Len()}) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.init(&Config{}) From ba942715dbada3d26b80ac375836cbcefae4ec51 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 20 Apr 2023 11:37:56 +0200 Subject: [PATCH 4/8] remove ConnectionIDLength and ConnectionIDGenerator from the Config --- client.go | 23 ++++-- client_test.go | 31 ++++--- config.go | 26 +----- config_test.go | 17 +--- connection.go | 6 +- connection_test.go | 6 +- integrationtests/self/conn_id_test.go | 81 +++++++++++-------- integrationtests/self/handshake_test.go | 4 +- integrationtests/self/mitm_test.go | 42 ++++++---- integrationtests/self/multiplex_test.go | 1 - integrationtests/self/stateless_reset_test.go | 29 ++++--- integrationtests/self/zero_rtt_test.go | 46 ++++++++--- interface.go | 12 --- server.go | 14 ++-- server_test.go | 11 +++ transport.go | 36 +++++---- transport_test.go | 30 +++---- 17 files changed, 232 insertions(+), 183 deletions(-) diff --git a/client.go b/client.go index c8ea0641a..ed6ccfb82 100644 --- a/client.go +++ b/client.go @@ -25,8 +25,9 @@ type client struct { tlsConf *tls.Config config *Config - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + connIDGenerator ConnectionIDGenerator + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool @@ -133,6 +134,7 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo func dial( ctx context.Context, conn net.PacketConn, + connIDGenerator ConnectionIDGenerator, packetHandlers packetHandlerManager, addr net.Addr, tlsConf *tls.Config, @@ -141,7 +143,7 @@ func dial( use0RTT bool, createdPacketConn bool, ) (quicConn, error) { - c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn) + c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn) if err != nil { return nil, err } @@ -164,14 +166,23 @@ func dial( return c.conn, nil } -func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) { +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + connIDGenerator ConnectionIDGenerator, + config *Config, + tlsConf *tls.Config, + onClose func(), + use0RTT bool, + createdPacketConn bool, +) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { tlsConf = tlsConf.Clone() } - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err } @@ -180,6 +191,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon return nil, err } c := &client{ + connIDGenerator: connIDGenerator, srcConnID: srcConnID, destConnID: destConnID, sconn: newSendPconn(pconn, remoteAddr), @@ -203,6 +215,7 @@ func (c *client) dial(ctx context.Context) error { c.packetHandlers, c.destConnID, c.srcConnID, + c.connIDGenerator, c.config, c.tlsConf, c.initialPacketNumber, diff --git a/client_test.go b/client_test.go index 71e16f0b4..ce53ef4b9 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,7 @@ var _ = Describe("Client", func() { runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -114,6 +115,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -132,7 +134,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -150,6 +152,7 @@ var _ = Describe("Client", func() { runner connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -168,7 +171,7 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -186,6 +189,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -203,7 +207,7 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false) + cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -219,16 +223,14 @@ var _ = Describe("Client", func() { MaxIdleTimeout: 42 * time.Hour, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, - ConnectionIDLength: 13, TokenStore: tokenStore, EnableDatagrams: true, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute)) Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour)) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) - Expect(c.ConnectionIDLength).To(Equal(13)) Expect(c.TokenStore).To(Equal(tokenStore)) Expect(c.EnableDatagrams).To(BeTrue()) }) @@ -238,7 +240,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: -1, MaxIncomingUniStreams: 4321, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.MaxIncomingStreams).To(BeZero()) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) }) @@ -248,18 +250,13 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: -1, } - c := populateClientConfig(config, false) + c := populateConfig(config) Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) - It("uses 0-byte connection IDs when dialing an address", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) - It("fills in default values if options are not set in the Config", func() { - c := populateClientConfig(&Config{}, false) + c := populateConfig(&Config{}) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -267,7 +264,7 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} + config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -278,6 +275,7 @@ var _ = Describe("Client", func() { _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, configP *Config, _ *tls.Config, _ protocol.PacketNumber, @@ -320,6 +318,7 @@ var _ = Describe("Client", func() { runner connRunner, _ protocol.ConnectionID, connID protocol.ConnectionID, + _ ConnectionIDGenerator, configP *Config, _ *tls.Config, pn protocol.PacketNumber, @@ -352,7 +351,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) diff --git a/config.go b/config.go index 7a48b4073..ceb4fb682 100644 --- a/config.go +++ b/config.go @@ -42,7 +42,7 @@ func validateConfig(config *Config) error { // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config, protocol.DefaultConnectionIDLength) + config = populateConfig(config) if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } @@ -55,19 +55,9 @@ func populateServerConfig(config *Config) *Config { return config } -// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config, createdPacketConn bool) *Config { - defaultConnIDLen := protocol.DefaultConnectionIDLength - if createdPacketConn { - defaultConnIDLen = 0 - } - - config = populateConfig(config, defaultConnIDLen) - return config -} - -func populateConfig(config *Config, defaultConnIDLen int) *Config { +func populateConfig(config *Config) *Config { if config == nil { config = &Config{} } @@ -75,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } - conIDLen := config.ConnectionIDLength - if config.ConnectionIDLength == 0 { - conIDLen = defaultConnIDLen - } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -115,10 +101,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } - connIDGenerator := config.ConnectionIDGenerator - if connIDGenerator == nil { - connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen} - } return &Config{ Versions: versions, @@ -135,8 +117,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: conIDLen, - ConnectionIDGenerator: connIDGenerator, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, diff --git a/config_test.go b/config_test.go index b9ca6a5cd..f319deb26 100644 --- a/config_test.go +++ b/config_test.go @@ -142,18 +142,18 @@ var _ = Describe("Config", func() { var calledAddrValidation bool c1 := &Config{} c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1, protocol.DefaultConnectionIDLength) + c2 := populateConfig(c1) c2.RequireAddressValidation(&net.UDPAddr{}) Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() - Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c)) + Expect(populateConfig(c)).To(Equal(c)) }) It("populates empty fields with default values", func() { - c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength) + c := populateConfig(&Config{}) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) @@ -168,18 +168,7 @@ var _ = Describe("Config", func() { It("populates empty fields with default values, for the server", func() { c := populateServerConfig(&Config{}) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) Expect(c.RequireAddressValidation).ToNot(BeNil()) }) - - It("sets a default connection ID length if we didn't create the conn, for the client", func() { - c := populateClientConfig(&Config{}, false) - Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - }) - - It("doesn't set a default connection ID length if we created the conn, for the client", func() { - c := populateClientConfig(&Config{}, true) - Expect(c.ConnectionIDLength).To(BeZero()) - }) }) }) diff --git a/connection.go b/connection.go index eb16ece5d..78969ed91 100644 --- a/connection.go +++ b/connection.go @@ -240,6 +240,7 @@ var newConnection = func( clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, statelessResetToken protocol.StatelessResetToken, conf *Config, tlsConf *tls.Config, @@ -283,7 +284,7 @@ var newConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -363,6 +364,7 @@ var newClientConnection = func( runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -402,7 +404,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) diff --git a/connection_test.go b/connection_test.go index a2829e854..10f0c202b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -113,6 +113,7 @@ var _ = Describe("Connection", func() { clientDestConnID, destConnID, srcConnID, + &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, populateServerConfig(&Config{DisablePathMTUDiscovery: true}), nil, // tls.Config @@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() { packer.EXPECT().HandleTransportParameters(params) packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3) Expect(conn.earlyConnReady()).ToNot(BeClosed()) - connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) - connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) Expect(conn.earlyConnReady()).To(BeClosed()) @@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() { } BeforeEach(func() { - quicConf = populateClientConfig(&Config{}, true) + quicConf = populateConfig(&Config{}) tlsConf = nil }) @@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() { connRunner, destConnID, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + &protocol.DefaultConnectionIDGenerator{}, quicConf, tlsConf, 42, // initial packet number diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 7cb1904d3..0d8c4b446 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -34,13 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int { var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } - runServer := func(conf *quic.Config) *quic.Listener { - if conf.ConnectionIDGenerator != nil { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + // connIDLen is ignored when connIDGenerator is set + runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) { + if connIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen()))) } else { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen))) } - ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf) + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: conn, + ConnectionIDLength: connIDLen, + ConnectionIDGenerator: connIDGenerator, + } + ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -59,20 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() { }() } }() - return ln + return ln, func() { + ln.Close() + tr.Close() + } } - runClient := func(addr net.Addr, conf *quic.Config) { - if conf.ConnectionIDGenerator != nil { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + // connIDLen is ignored when connIDGenerator is set + runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) { + if connIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen()))) } else { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen))) } - cl, err := quic.DialAddr( + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + tr := &quic.Transport{ + Conn: conn, + ConnectionIDLength: connIDLen, + ConnectionIDGenerator: connIDGenerator, + } + defer tr.Close() + cl, err := tr.Dial( context.Background(), - fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port}, getTLSClientConfig(), - conf, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) defer cl.CloseWithError(0, "") @@ -84,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() { } It("downloads a file using a 0-byte connection ID for the client", func() { - serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), getQuicConfig(nil)) + ln, closeFn := runServer(randomConnIDLen(), nil) + defer closeFn() + runClient(ln.Addr(), 0, nil) }) It("downloads a file when both client and server use a random connection ID length", func() { - serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), getQuicConfig(nil)) + ln, closeFn := runServer(randomConnIDLen(), nil) + defer closeFn() + runClient(ln.Addr(), randomConnIDLen(), nil) }) It("downloads a file when both client and server use a custom connection ID generator", func() { - serverConf := getQuicConfig(&quic.Config{ - ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, - }) - clientConf := getQuicConfig(&quic.Config{ - ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, - }) - - ln := runServer(serverConf) - defer ln.Close() - - runClient(ln.Addr(), clientConf) + ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()}) + defer closeFn() + runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()}) }) }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index b3a13e9d4..fc77f4244 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -244,7 +244,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) - dialer = &quic.Transport{Conn: pconn} + dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} }) AfterEach(func() { @@ -303,7 +303,7 @@ var _ = Describe("Handshake tests", func() { // This should free one spot in the queue. Expect(firstConn.CloseWithError(0, "")) Eventually(firstConn.Context().Done()).Should(BeClosed()) - time.Sleep(scaleDuration(20 * time.Millisecond)) + time.Sleep(scaleDuration(200 * time.Millisecond)) // dial again, and expect that this dial succeeds _, err = dial() diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 133c6f236..35e0af913 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() { Expect(err).ToNot(HaveOccurred()) serverUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) + tr := &quic.Transport{ + Conn: serverUDPConn, + ConnectionIDLength: connIDLen, + } + ln, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { @@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() { } BeforeEach(func() { - serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}) + serverConfig = getQuicConfig(nil) addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) clientUDPConn, err = net.ListenUDP("udp", addr) @@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + conn, err := tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) @@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + conn, err := tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) @@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() { const rtt = 20 * time.Millisecond runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - _, err = quic.Dial( + tr := &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } + _, err = tr.Dial( context.Background(), - clientUDPConn, raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{ - ConnectionIDLength: connIDLen, - HandshakeIdleTimeout: 2 * time.Second, - }), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), ) - return closeFn, err + return func() { tr.Close(); serverCloseFn() }, err } // fails immediately because client connection closes when it can't find compatible version diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index dcac1b469..2623e3227 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -137,7 +137,6 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} - server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 3c9a17034..4ceb80677 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -28,8 +28,9 @@ var _ = Describe("Stateless Resets", func() { c, err := net.ListenUDP("udp", nil) Expect(err).ToNot(HaveOccurred()) tr := &quic.Transport{ - Conn: c, - StatelessResetKey: &statelessResetKey, + Conn: c, + StatelessResetKey: &statelessResetKey, + ConnectionIDLength: connIDLen, } defer tr.Close() ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) @@ -61,14 +62,21 @@ var _ = Describe("Stateless Resets", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - conn, err := quic.DialAddr( + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + cl := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer cl.Close() + conn, err := cl.Dial( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()}, getTLSClientConfig(), - getQuicConfig(&quic.Config{ - ConnectionIDLength: connIDLen, - MaxIdleTimeout: 2 * time.Second, - }), + getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptStream(context.Background()) @@ -86,8 +94,9 @@ var _ = Describe("Stateless Resets", func() { // We need to create a new Transport here, since the old one is still sending out // CONNECTION_CLOSE packets for (recently) closed connections). tr2 := &quic.Transport{ - Conn: c, - StatelessResetKey: &statelessResetKey, + Conn: c, + ConnectionIDLength: connIDLen, + StatelessResetKey: &statelessResetKey, } defer tr2.Close() ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil)) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 1559c1760..e83a9f965 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() { transfer0RTTData := func( ln *quic.EarlyListener, proxyPort int, + connIDLen int, clientTLSConf *tls.Config, clientConf *quic.Config, testdata []byte, // data to transfer @@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() { if clientConf == nil { clientConf = getQuicConfig(nil) } - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxyPort), - clientTLSConf, - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) + var conn quic.EarlyConnection + if connIDLen == 0 { + var err error + conn, err = quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } else { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer tr.Close() + conn, err = tr.DialEarly( + context.Background(), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort}, + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } defer conn.CloseWithError(0, "") str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() { transfer0RTTData( ln, proxy.LocalPort(), + connIDLen, clientTLSConf, - getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + getQuicConfig(nil), PRData, ) @@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) num0RTT := atomic.LoadUint32(&num0RTTPackets) numDropped := atomic.LoadUint32(&num0RTTDropped) @@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets mutex.Lock() defer mutex.Unlock() @@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) diff --git a/interface.go b/interface.go index 33309facf..267c07f5c 100644 --- a/interface.go +++ b/interface.go @@ -242,18 +242,6 @@ type Config struct { // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber - // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. - // If not set, the interpretation depends on where the Config is used: - // If used for dialing an address, a 0 byte connection ID will be used. - // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. - // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. - ConnectionIDLength int - // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. - // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. - // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. - // Otherwise, if one is provided, then ConnectionIDLength is ignored. - ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. diff --git a/server.go b/server.go index f8c9b3cd1..a80709ea3 100644 --- a/server.go +++ b/server.go @@ -68,8 +68,9 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator - connHandler packetHandlerManager - onClose func() + connIDGenerator ConnectionIDGenerator + connHandler packetHandlerManager + onClose func() receivedPackets chan *receivedPacket @@ -85,6 +86,7 @@ type baseServer struct { protocol.ConnectionID, /* client dest connection ID */ protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ + ConnectionIDGenerator, protocol.StatelessResetToken, *Config, *tls.Config, @@ -210,7 +212,7 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear return tr.ListenEarly(tlsConf, config) } -func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { +func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err @@ -220,6 +222,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Conf tlsConf: tlsConf, config: config, tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, connHandler: connHandler, connQueue: make(chan quicConn), errorChan: make(chan struct{}), @@ -574,7 +577,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } @@ -603,6 +606,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro hdr.DestConnectionID, hdr.SrcConnectionID, connID, + s.connIDGenerator, s.connHandler.GetStatelessResetToken(connID), s.config, s.tlsConf, @@ -669,7 +673,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } diff --git a/server_test.go b/server_test.go index 658bb45bd..7a17ffd51 100644 --- a/server_test.go +++ b/server_test.go @@ -286,6 +286,7 @@ var _ = Describe("Server", func() { clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -488,6 +489,7 @@ var _ = Describe("Server", func() { clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + _ ConnectionIDGenerator, tokenP protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -547,6 +549,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -600,6 +603,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -631,6 +635,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -702,6 +707,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1009,6 +1015,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1082,6 +1089,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1124,6 +1132,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1187,6 +1196,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, @@ -1309,6 +1319,7 @@ var _ = Describe("Server", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ ConnectionIDGenerator, _ protocol.StatelessResetToken, _ *Config, _ *tls.Config, diff --git a/transport.go b/transport.go index bf902fbc5..ebb144f41 100644 --- a/transport.go +++ b/transport.go @@ -59,6 +59,9 @@ type Transport struct { // Set in init. // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. connIDLen int + // Set in init. + // If no ConnectionIDGenerator is set, this is set to a default. + connIDGenerator ConnectionIDGenerator server unknownPacketHandler @@ -92,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf); err != nil { + if err := t.init(conf, true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false) if err != nil { return nil, err } @@ -121,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf); err != nil { + if err := t.init(conf, true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true) if err != nil { return nil, err } @@ -137,15 +140,15 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if err := validateConfig(conf); err != nil { return nil, err } - conf = populateClientConfig(conf, t.createdConn) - if err := t.init(conf); err != nil { + conf = populateConfig(conf) + if err := t.init(conf, false); err != nil { return nil, err } var onClose func() if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) + return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. @@ -153,15 +156,15 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if err := validateConfig(conf); err != nil { return nil, err } - conf = populateClientConfig(conf, t.createdConn) - if err := t.init(conf); err != nil { + conf = populateConfig(conf) + if err := t.init(conf, false); err != nil { return nil, err } var onClose func() if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) + return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) } func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { @@ -197,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { // only print warnings about the UDP receive buffer size once var receiveBufferWarningOnce sync.Once -func (t *Transport) init(conf *Config) error { +func (t *Transport) init(conf *Config, isServer bool) error { t.initOnce.Do(func() { getMultiplexer().AddConn(t.Conn) @@ -208,9 +211,6 @@ func (t *Transport) init(conf *Config) error { } t.Tracer = conf.Tracer - t.ConnectionIDLength = conf.ConnectionIDLength - t.ConnectionIDGenerator = conf.ConnectionIDGenerator - t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) @@ -219,9 +219,15 @@ func (t *Transport) init(conf *Config) error { t.closeQueue = make(chan closePacket, 4) if t.ConnectionIDGenerator != nil { + t.connIDGenerator = t.ConnectionIDGenerator t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() } else { - t.connIDLen = t.ConnectionIDLength + connIDLen := t.ConnectionIDLength + if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) { + connIDLen = protocol.DefaultConnectionIDLength + } + t.connIDLen = connIDLen + t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } go t.listen(conn) diff --git a/transport_test.go b/transport_test.go index c404a3b92..19ef54f20 100644 --- a/transport_test.go +++ b/transport_test.go @@ -61,7 +61,7 @@ var _ = Describe("Transport", func() { It("handles packets for different packet handlers on the same packet conn", func() { packetChan := make(chan packetToRead) tr := &Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}) + tr.init(&Config{}, true) phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) @@ -126,9 +126,10 @@ var _ = Describe("Transport", func() { packetChan := make(chan packetToRead) tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ - Conn: newMockPacketConn(packetChan), + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, } - tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10}) + tr.init(&Config{Tracer: tracer}, true) dropped := make(chan struct{}) tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ @@ -147,7 +148,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) + tr.init(&Config{}, true) tr.handlerMap = phm done := make(chan struct{}) @@ -165,7 +166,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) + tr.init(&Config{}, true) tr.handlerMap = phm tempErr := deadlineError{} @@ -183,11 +184,13 @@ var _ = Describe("Transport", func() { It("handles short header packets resets", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) - tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr := Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: connID.Len(), + } + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var token protocol.StatelessResetToken @@ -218,10 +221,9 @@ var _ = Describe("Transport", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var token protocol.StatelessResetToken @@ -251,13 +253,13 @@ var _ = Describe("Transport", func() { packetChan := make(chan packetToRead) conn := newMockPacketConn(packetChan) tr := Transport{ - Conn: conn, - StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + Conn: conn, + StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, + ConnectionIDLength: connID.Len(), } - tr.init(&Config{ConnectionIDLength: connID.Len()}) + tr.init(&Config{}, true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}) tr.handlerMap = phm var b []byte From bc7cb706c55905ebbbaef3256794303f9d1ed35e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 25 Apr 2023 11:31:01 +0200 Subject: [PATCH 5/8] add a GetConfigForClient callback to the Config --- config.go | 1 + config_test.go | 7 +- integrationtests/self/handshake_test.go | 66 +++++++++++ interface.go | 7 ++ mock_packet_handler_manager_test.go | 2 +- packet_handler_map.go | 7 +- packet_handler_map_test.go | 8 +- server.go | 25 +++- server_test.go | 147 ++++++++++++++++++------ 9 files changed, 223 insertions(+), 47 deletions(-) diff --git a/config.go b/config.go index ceb4fb682..fbd6da179 100644 --- a/config.go +++ b/config.go @@ -103,6 +103,7 @@ func populateConfig(config *Config) *Config { } return &Config{ + GetConfigForClient: config.GetConfigForClient, Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, diff --git a/config_test.go b/config_test.go index f319deb26..3de7a173e 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "net" "reflect" @@ -45,7 +46,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -108,6 +109,7 @@ var _ = Describe("Config", func() { It("clones function fields", func() { var calledAddrValidation, calledAllowConnectionWindowIncrease bool c1 := &Config{ + GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, } @@ -116,6 +118,8 @@ var _ = Describe("Config", func() { Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) + _, err := c2.GetConfigForClient(&ClientHelloInfo{}) + Expect(err).To(MatchError("nope")) }) It("clones non-function fields", func() { @@ -164,6 +168,7 @@ var _ = Describe("Config", func() { Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) Expect(c.DisablePathMTUDiscovery).To(BeFalse()) + Expect(c.GetConfigForClient).To(BeNil()) }) It("populates empty fields with default values, for the server", func() { diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index fc77f4244..0f1c36786 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -436,6 +436,72 @@ var _ = Describe("Handshake tests", func() { }) }) + Context("GetConfigForClient", func() { + It("uses the quic.Config returned by GetConfigForClient", func() { + serverConfig.EnableDatagrams = false + var calledFrom net.Addr + serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { + conf := serverConfig.Clone() + conf.EnableDatagrams = true + calledFrom = info.RemoteAddr + return getQuicConfig(conf), nil + } + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + cs := conn.ConnectionState() + Expect(cs.SupportsDatagrams).To(BeTrue()) + Eventually(done).Should(BeClosed()) + Expect(ln.Close()).To(Succeed()) + Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port)) + }) + + It("rejects the connection attempt if GetConfigForClient errors", func() { + serverConfig.EnableDatagrams = false + serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { + return nil, errors.New("rejected") + } + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := ln.Accept(context.Background()) + Expect(err).To(HaveOccurred()) // we don't expect to accept any connection + close(done) + }() + + _, err = quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) + }) + }) + It("doesn't send any packets when generating the ClientHello fails", func() { ln, err := net.ListenUDP("udp", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 267c07f5c..29a20958e 100644 --- a/interface.go +++ b/interface.go @@ -239,6 +239,9 @@ type ConnectionIDGenerator interface { // Config contains all configuration data needed for a QUIC server or client. type Config struct { + // GetConfigForClient is called for incoming connections. + // If the error is not nil, the connection attempt is refused. + GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber @@ -324,6 +327,10 @@ type Config struct { Tracer logging.Tracer } +type ClientHelloInfo struct { + RemoteAddr net.Addr +} + // ConnectionState records basic details about a QUIC connection type ConnectionState struct { TLS handshake.ConnectionState diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 25ae5420c..7b70a8dbd 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa } // AddWithConnID mocks base method. -func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool { +func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) ret0, _ := ret[0].(bool) diff --git a/packet_handler_map.go b/packet_handler_map.go index 2a08359a9..83caa1920 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -122,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { h.mutex.Lock() defer h.mutex.Unlock() @@ -130,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } - conn := fn() + conn, ok := fn() + if !ok { + return false + } h.handlers[clientDestConnID] = conn h.handlers[newConnID] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index e87a75f88..2969bb5b9 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -62,14 +62,14 @@ var _ = Describe("Packet Handler Map", func() { var called bool connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - Expect(m.AddWithConnID(connID1, connID2, func() packetHandler { + Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) { called = true - return NewMockPacketHandler(mockCtrl) + return NewMockPacketHandler(mockCtrl), true })).To(BeTrue()) Expect(called).To(BeTrue()) - Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler { + Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) { Fail("didn't expect the constructor to be executed") - return nil + return nil, false })).To(BeFalse()) }) diff --git a/server.go b/server.go index a80709ea3..a38a57848 100644 --- a/server.go +++ b/server.go @@ -33,7 +33,7 @@ type packetHandler interface { type packetHandlerManager interface { Get(protocol.ConnectionID) (packetHandler, bool) GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool Close(error) CloseServer() connRunner @@ -584,7 +584,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro s.logger.Debugf("Changing connection ID to %s.", connID) var conn quicConn tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { var tracer logging.ConnectionTracer if s.config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. @@ -598,6 +598,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro connID, ) } + config := s.config + if s.config.GetConfigForClient != nil { + conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) + if err != nil { + s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") + return nil, false + } + config = populateConfig(conf) + } conn = s.newConn( newSendConn(s.conn, p.remoteAddr, p.info), s.connHandler, @@ -608,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro connID, s.connIDGenerator, s.connHandler.GetStatelessResetToken(connID), - s.config, + config, s.tlsConf, s.tokenGenerator, clientAddrIsValid, @@ -626,10 +635,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro delete(s.zeroRTTQueues, hdr.DestConnectionID) } - return conn + return conn, true }); !added { - // TODO: don't just drop the packet - // Properly reject the connection attempt. + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() return nil } go conn.run() diff --git a/server_test.go b/server_test.go index 7a17ffd51..e78356c81 100644 --- a/server_test.go +++ b/server_test.go @@ -267,14 +267,14 @@ var _ = Describe("Server", func() { var newConnID protocol.ConnectionID phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c return token }) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQUICConn(mockCtrl) @@ -468,14 +468,14 @@ var _ = Describe("Server", func() { var newConnID protocol.ConnectionID gomock.InOrder( phm.EXPECT().Get(connID), - phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c return token }) - fn() - return true + _, ok := fn() + return ok }), ) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) @@ -532,10 +532,10 @@ var _ = Describe("Server", func() { It("drops packets if the receive queue is full", func() { phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() @@ -594,7 +594,6 @@ var _ = Describe("Server", func() { It("only creates a single connection for a duplicate Initial", func() { var createdConn bool - conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, runner connRunner, @@ -615,15 +614,19 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { createdConn = true - return conn + return NewMockQUICConn(mockCtrl) } connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(connID) phm.EXPECT().Get(connID) phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) }) Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(createdConn).To(BeFalse()) + Eventually(done).Should(BeClosed()) }) It("rejects new connection attempts if the accept queue is full", func() { @@ -657,10 +660,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).Times(protocol.MaxAcceptQueueSize) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) @@ -729,10 +732,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) @@ -792,7 +795,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) }) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) }) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -994,6 +997,84 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + It("uses the config returned by GetConfigClient", func() { + conn := NewMockQUICConn(mockCtrl) + + conf := &Config{MaxIncomingStreams: 1234} + serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }}) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(conn)) + close(done) + }() + + handshakeChan := make(chan struct{}) + serv.newConn = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ ConnectionIDGenerator, + _ protocol.StatelessResetToken, + conf *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().Context().Return(context.Background()) + return conn + } + phm.EXPECT().Get(gomock.Any()) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + _, ok := fn() + return ok + }) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, + ) + Consistently(done).ShouldNot(BeClosed()) + close(handshakeChan) // complete the handshake + Eventually(done).Should(BeClosed()) + }) + + It("rejects a connection attempt when GetConfigClient returns an error", func() { + serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }}) + + phm.EXPECT().Get(gomock.Any()) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + _, ok := fn() + return ok + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + return len(b), nil + }) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, + ) + Eventually(done).Should(BeClosed()) + }) + It("accepts new connections when the handshake completes", func() { conn := NewMockQUICConn(mockCtrl) @@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( @@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() { return conn } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.baseServer.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, @@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()).AnyTimes() - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) @@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() { } phm.EXPECT().Get(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.baseServer.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any()) phm.EXPECT().Get(connID) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) - fn() - return true + _, ok := fn() + return ok }) serv.handlePacket(initial) Eventually(called).Should(BeClosed()) From 7a0ef5f8675977ae661f313d0597b87b523f8ee6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 22 Mar 2023 13:25:09 +1300 Subject: [PATCH 6/8] make Config.Allow0RTT a bool, not a callback --- config_test.go | 4 +++- connection.go | 6 +----- fuzzing/handshake/cmd/corpus.go | 2 +- fuzzing/handshake/fuzz.go | 6 +----- http3/server.go | 2 +- integrationtests/self/zero_rtt_test.go | 24 ++++++++++++------------ interface.go | 4 +--- internal/handshake/crypto_setup.go | 10 +++++----- internal/handshake/crypto_setup_test.go | 20 ++++++++------------ interop/server/main.go | 6 ++---- 10 files changed, 35 insertions(+), 49 deletions(-) diff --git a/config_test.go b/config_test.go index 3de7a173e..d796b3643 100644 --- a/config_test.go +++ b/config_test.go @@ -46,7 +46,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -86,6 +86,8 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(true)) case "DisablePathMTUDiscovery": f.Set(reflect.ValueOf(true)) + case "Allow0RTT": + f.Set(reflect.ValueOf(true)) case "Tracer": f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) default: diff --git a/connection.go b/connection.go index 78969ed91..6833a7b94 100644 --- a/connection.go +++ b/connection.go @@ -324,10 +324,6 @@ var newConnection = func( if s.tracer != nil { s.tracer.SentTransportParameters(params) } - var allow0RTT func() bool - if conf.Allow0RTT != nil { - allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) } - } cs := handshake.NewCryptoSetupServer( initialStream, handshakeStream, @@ -345,7 +341,7 @@ var newConnection = func( }, }, tlsConf, - allow0RTT, + conf.Allow0RTT, s.rttStats, tracer, logger, diff --git a/fuzzing/handshake/cmd/corpus.go b/fuzzing/handshake/cmd/corpus.go index a5e0ffb40..1142bea61 100644 --- a/fuzzing/handshake/cmd/corpus.go +++ b/fuzzing/handshake/cmd/corpus.go @@ -105,7 +105,7 @@ func main() { &wire.TransportParameters{ActiveConnectionIDLimit: 2}, runner, config, - nil, + false, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 37055092c..2d73e6056 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -390,10 +390,6 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. protocol.Version1, ) - var allow0RTT func() bool - if enable0RTTServer { - allow0RTT = func() bool { return true } - } sChunkChan, sInitialStream, sHandshakeStream := initStreams() server = handshake.NewCryptoSetupServer( sInitialStream, @@ -404,7 +400,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. serverTP, runner, serverConf, - allow0RTT, + enable0RTTServer, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/http3/server.go b/http3/server.go index 09b62a2b6..b0d8e4a40 100644 --- a/http3/server.go +++ b/http3/server.go @@ -288,7 +288,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { baseConf := ConfigureTLSConfig(tlsConf) quicConf := s.QuicConfig if quicConf == nil { - quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }} + quicConf = &quic.Config{Allow0RTT: true} } else { quicConf = s.QuicConfig.Clone() } diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index e83a9f965..684cfe6ae 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -54,7 +54,7 @@ var _ = Describe("0-RTT", func() { if serverConf == nil { serverConf = getQuicConfig(nil) } - serverConf.Allow0RTT = func(addr net.Addr) bool { return true } + serverConf.Allow0RTT = true ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -222,7 +222,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(addr net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -358,7 +358,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -434,7 +434,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ RequireAddressValidation: func(net.Addr) bool { return true }, - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -495,7 +495,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -540,7 +540,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ MaxIncomingStreams: maxStreams - 1, - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -568,7 +568,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -595,7 +595,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT + Allow0RTT: false, // application rejects 0-RTT Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -616,12 +616,12 @@ var _ = Describe("0-RTT", func() { DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { tracer := newPacketTracer() - firstConf := getQuicConfig(&quic.Config{Allow0RTT: func(net.Addr) bool { return true }}) + firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) secondConf := getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) addFlowControlLimit(secondConf, 100) @@ -774,7 +774,7 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, + Allow0RTT: true, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) diff --git a/interface.go b/interface.go index 29a20958e..4d2a91976 100644 --- a/interface.go +++ b/interface.go @@ -318,10 +318,8 @@ type Config struct { // It has no effect for a client. DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. - // When set, 0-RTT is enabled. When not set, 0-RTT is disabled. // Only valid for the server. - // Warning: This API should not be considered stable and might change soon. - Allow0RTT func(net.Addr) bool + Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool Tracer logging.Tracer diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 0420a5f92..8c9c2a8f8 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -116,7 +116,7 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written zeroRTTParametersChan chan<- *wire.TransportParameters - allow0RTT func() bool + allow0RTT bool rttStats *utils.RTTStats @@ -197,7 +197,7 @@ func NewCryptoSetupServer( tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, - allow0RTT func() bool, + allow0RTT bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, @@ -210,14 +210,13 @@ func NewCryptoSetupServer( tp, runner, tlsConf, - allow0RTT != nil, + allow0RTT, rttStats, tracer, logger, protocol.PerspectiveServer, version, ) - cs.allow0RTT = allow0RTT cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs } @@ -253,6 +252,7 @@ func newCryptoSetup( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, + allow0RTT: enable0RTT, ourParams: tp, paramsChan: extHandler.TransportParameters(), rttStats: rttStats, @@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") return false } - if !h.allow0RTT() { + if !h.allow0RTT { h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") return false } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index d21177248..51d1980a0 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, serverConf, - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, NewMockHandshakeRunner(mockCtrl), serverConf, - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -378,10 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - var allow0RTT func() bool - if enable0RTT { - allow0RTT = func() bool { return true } - } var sHandshakeComplete bool sChunkChan, sInitialStream, sHandshakeStream := initStreams() sErrChan := make(chan error, 1) @@ -402,7 +398,7 @@ var _ = Describe("Crypto Setup TLS", func() { serverTransportParameters, sRunner, serverConf, - allow0RTT, + enable0RTT, serverRTTStats, nil, utils.DefaultLogger.WithPrefix("server"), @@ -541,7 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() { sTransportParameters, sRunner, serverConf, - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -596,7 +592,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -655,7 +651,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, - nil, + false, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/interop/server/main.go b/interop/server/main.go index 867044b48..6c607c5c0 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -46,6 +46,7 @@ func main() { // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, + Allow0RTT: testcase == "zerortt", Tracer: qlog.NewTracer(getLogWriter), } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") @@ -59,10 +60,7 @@ func main() { } switch testcase { - case "zerortt": - quicConf.Allow0RTT = func(net.Addr) bool { return true } - fallthrough - case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect": + case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt": err = runHTTP09Server(quicConf) case "chacha20": reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) From 5544f0f9a1ef627fde4c6bc92d0f928d53f7edd6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 25 Apr 2023 13:48:57 +0200 Subject: [PATCH 7/8] simplify connection handling for the client --- client.go | 51 +++++++++++++++++------------------------------ client_test.go | 22 ++++++++++---------- send_conn.go | 23 +-------------------- send_conn_test.go | 4 +++- transport.go | 4 ++-- 5 files changed, 34 insertions(+), 70 deletions(-) diff --git a/client.go b/client.go index ed6ccfb82..e1f03c87d 100644 --- a/client.go +++ b/client.go @@ -12,10 +12,7 @@ import ( ) type client struct { - sconn sendConn - // If the client is created with DialAddr, we create a packet conn. - // If it is started with Dial, we take a packet conn as a parameter. - createdPacketConn bool + sendConn sendConn use0RTT bool @@ -133,17 +130,15 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo func dial( ctx context.Context, - conn net.PacketConn, + conn sendConn, connIDGenerator ConnectionIDGenerator, packetHandlers packetHandlerManager, - addr net.Addr, tlsConf *tls.Config, config *Config, onClose func(), use0RTT bool, - createdPacketConn bool, ) (quicConn, error) { - c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn) + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) if err != nil { return nil, err } @@ -158,7 +153,7 @@ func dial( ) } if c.tracer != nil { - c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { return nil, err @@ -166,16 +161,7 @@ func dial( return c.conn, nil } -func newClient( - pconn net.PacketConn, - remoteAddr net.Addr, - connIDGenerator ConnectionIDGenerator, - config *Config, - tlsConf *tls.Config, - onClose func(), - use0RTT bool, - createdPacketConn bool, -) (*client, error) { +func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { @@ -191,27 +177,26 @@ func newClient( return nil, err } c := &client{ - connIDGenerator: connIDGenerator, - srcConnID: srcConnID, - destConnID: destConnID, - sconn: newSendPconn(pconn, remoteAddr), - createdPacketConn: createdPacketConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), + connIDGenerator: connIDGenerator, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, + use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } return c, nil } func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.conn = newClientConnection( - c.sconn, + c.sendConn, c.packetHandlers, c.destConnID, c.srcConnID, diff --git a/client_test.go b/client_test.go index ce53ef4b9..a3c31e1db 100644 --- a/client_test.go +++ b/client_test.go @@ -26,8 +26,7 @@ func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } var _ = Describe("Client", func() { var ( cl *client - packetConn *MockPacketConn - addr net.Addr + packetConn *MockSendConn connID protocol.ConnectionID origMultiplexer multiplexer tlsConf *tls.Config @@ -62,14 +61,14 @@ var _ = Describe("Client", func() { tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} Eventually(areConnsRunning).Should(BeFalse()) - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) + packetConn = NewMockSendConn(mockCtrl) packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() cl = &client{ srcConnID: connID, destConnID: connID, version: protocol.Version1, - sconn: newSendPconn(packetConn, addr), + sendConn: packetConn, tracer: tracer, logger: utils.DefaultLogger, } @@ -134,7 +133,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -171,7 +170,7 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -207,7 +206,7 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -266,7 +265,6 @@ var _ = Describe("Client", func() { It("creates new connections with the right parameters", func() { config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} c := make(chan struct{}) - var cconn sendConn var version protocol.VersionNumber var conf *Config done := make(chan struct{}) @@ -286,7 +284,6 @@ var _ = Describe("Client", func() { _ utils.Logger, versionP protocol.VersionNumber, ) quicConn { - cconn = connP version = versionP conf = configP close(c) @@ -298,15 +295,16 @@ var _ = Describe("Client", func() { close(done) return conn } + packetConn := NewMockPacketConn(mockCtrl) packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { <-done return 0, nil, errors.New("closed") }) + packetConn.EXPECT().LocalAddr() packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() - _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) + _, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) - Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn)) Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) }) diff --git a/send_conn.go b/send_conn.go index c53ebdfab..0ac270378 100644 --- a/send_conn.go +++ b/send_conn.go @@ -22,7 +22,7 @@ type sconn struct { var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { return &sconn{ rawConn: c, remoteAddr: remote, @@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr { } return addr } - -type spconn struct { - net.PacketConn - - remoteAddr net.Addr -} - -var _ sendConn = &spconn{} - -func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { - return &spconn{PacketConn: c, remoteAddr: remote} -} - -func (c *spconn) Write(p []byte) error { - _, err := c.WriteTo(p, c.remoteAddr) - return err -} - -func (c *spconn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/send_conn_test.go b/send_conn_test.go index 6c36c1b65..2da3e3ab9 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() { BeforeEach(func() { addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = NewMockPacketConn(mockCtrl) - c = newSendPconn(packetConn, addr) + rawConn, err := wrapConn(packetConn) + Expect(err).ToNot(HaveOccurred()) + c = newSendConn(rawConn, addr, nil) }) It("writes", func() { diff --git a/transport.go b/transport.go index ebb144f41..baeb592b7 100644 --- a/transport.go +++ b/transport.go @@ -148,7 +148,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. @@ -164,7 +164,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { From 07ad2cbee28288aab7cb36d8943fd9444cf83044 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 28 Apr 2023 12:49:11 +0200 Subject: [PATCH 8/8] remove Tracer from Config, put ConnectionTracer constructor there --- client.go | 6 +- client_test.go | 10 +- config_test.go | 15 ++- example/client/main.go | 7 +- example/main.go | 7 +- integrationtests/self/key_update_test.go | 4 +- integrationtests/self/packetization_test.go | 4 +- integrationtests/self/self_suite_test.go | 25 ++-- integrationtests/self/timeout_test.go | 2 +- integrationtests/self/tracer_test.go | 34 +++--- integrationtests/self/zero_rtt_test.go | 22 ++-- integrationtests/tools/qlog.go | 12 +- .../versionnegotiation/handshake_test.go | 20 +++- .../versionnegotiation/rtt_test.go | 2 +- .../versionnegotiation_suite_test.go | 25 ++-- interface.go | 2 +- internal/mocks/logging/tracer.go | 15 --- interop/client/main.go | 7 +- interop/server/main.go | 9 +- interop/utils/logging.go | 29 ++--- logging/interface.go | 7 -- logging/mock_tracer_test.go | 15 --- logging/multiplex.go | 11 -- logging/multiplex_test.go | 41 ------- logging/null_tracer.go | 4 - qlog/qlog.go | 21 ---- qlog/qlog_test.go | 17 +-- server.go | 108 ++++++++++-------- server_test.go | 22 ++-- transport.go | 15 ++- transport_test.go | 15 +-- 31 files changed, 202 insertions(+), 331 deletions(-) diff --git a/client.go b/client.go index e1f03c87d..3d0bca3df 100644 --- a/client.go +++ b/client.go @@ -146,11 +146,7 @@ func dial( c.tracingID = nextConnTracingID() if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, ConnectionTracingKey, c.tracingID), - protocol.PerspectiveClient, - c.destConnID, - ) + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } if c.tracer != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) diff --git a/client_test.go b/client_test.go index a3c31e1db..23847329c 100644 --- a/client_test.go +++ b/client_test.go @@ -56,10 +56,12 @@ var _ = Describe("Client", func() { connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - tr := mocklogging.NewMockTracer(mockCtrl) - tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) - config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} + config = &Config{ + Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { + return tracer + }, + Versions: []protocol.VersionNumber{protocol.Version1}, + } Eventually(areConnsRunning).Should(BeFalse()) packetConn = NewMockSendConn(mockCtrl) packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() diff --git a/config_test.go b/config_test.go index d796b3643..53482ca48 100644 --- a/config_test.go +++ b/config_test.go @@ -1,14 +1,15 @@ package quic import ( + "context" "errors" "fmt" "net" "reflect" "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -46,7 +47,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -88,8 +89,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(true)) case "Allow0RTT": f.Set(reflect.ValueOf(true)) - case "Tracer": - f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) default: Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) } @@ -109,11 +108,15 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAddrValidation, calledAllowConnectionWindowIncrease bool + var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, + Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { + calledTracer = true + return nil + }, } c2 := c1.Clone() c2.RequireAddressValidation(&net.UDPAddr{}) @@ -122,6 +125,8 @@ var _ = Describe("Config", func() { Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) _, err := c2.GetConfigForClient(&ClientHelloInfo{}) Expect(err).To(MatchError("nope")) + c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) + Expect(calledTracer).To(BeTrue()) }) It("clones non-function fields", func() { diff --git a/example/client/main.go b/example/client/main.go index 19da87262..83f810fd1 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "crypto/tls" "crypto/x509" "flag" @@ -57,15 +58,15 @@ func main() { var qconf quic.Config if *enableQlog { - qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser { + qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { log.Fatal(err) } log.Printf("Creating qlog file %s.\n", filename) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) + } } roundTripper := &http3.RoundTripper{ TLSClientConfig: &tls.Config{ diff --git a/example/main.go b/example/main.go index ea2a0babe..058144050 100644 --- a/example/main.go +++ b/example/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "crypto/md5" "errors" "flag" @@ -162,15 +163,15 @@ func main() { handler := setupHandler(*www) quicConf := &quic.Config{} if *enableQlog { - quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser { + quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { log.Fatal(err) } log.Printf("Creating qlog file %s.\n", filename) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) + } } var wg sync.WaitGroup diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 51ae6be01..c24bef27d 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}), + getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return &keyUpdateConnTracer{} + }}), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 86062bd56..740956c54 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -27,7 +27,7 @@ var _ = Describe("Packetization", func() { getTLSConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), + Tracer: newTracer(serverTracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -50,7 +50,7 @@ var _ = Describe("Packetization", func() { getTLSClientConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), + Tracer: newTracer(clientTracer), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 966a57158..10adc856b 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -87,7 +87,7 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer logging.Tracer + qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer enableQlog bool version quic.VersionNumber @@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config { if conf.Tracer == nil { conf.Tracer = qlogTracer } else if qlogTracer != nil { - conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer) + origTracer := conf.Tracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogTracer(ctx, p, connID), + origTracer(ctx, p, connID), + ) + } } } return conf @@ -232,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration { return time.Duration(scaleFactor) * d } -type tracer struct { - logging.NullTracer - createNewConnTracer func() logging.ConnectionTracer -} - -var _ logging.Tracer = &tracer{} - -func newTracer(c func() logging.ConnectionTracer) logging.Tracer { - return &tracer{createNewConnTracer: c} -} - -func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return t.createNewConnTracer() +func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } } type packet struct { diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index abc05dd79..5a01270d3 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() { getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: idleTimeout, - Tracer: newTracer(func() logging.ConnectionTracer { return tr }), + Tracer: newTracer(tr), DisablePathMTUDiscovery: true, }), ) diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 377fbe1ed..3bfae3c6b 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -19,14 +19,6 @@ import ( . "github.com/onsi/gomega" ) -type customTracer struct{ logging.NullTracer } - -func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return &customConnTracer{} -} - -type customConnTracer struct{ logging.NullConnectionTracer } - var _ = Describe("Handshake tests", func() { addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config { enableQlog := mrand.Int()%3 != 0 @@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() { fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer) - var tracers []logging.Tracer + var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer if enableQlog { - tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { + tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections - fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID) + fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) return nil } - fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID) - return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)) - })) + fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID) + }) } if enableCustomTracer { - tracers = append(tracers, &customTracer{}) + tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return logging.NullConnectionTracer{} + }) } c := conf.Clone() - c.Tracer = logging.NewMultiplexedTracer(tracers...) + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors)) + for _, c := range tracerConstructors { + if tr := c(ctx, p, connID); tr != nil { + tracers = append(tracers, tr) + } + } + return logging.NewMultiplexedConnectionTracer(tracers...) + } return c } diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 684cfe6ae..001ea5da2 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -223,7 +223,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -277,7 +277,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -359,7 +359,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -435,7 +435,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ RequireAddressValidation: func(net.Addr) bool { return true }, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -496,7 +496,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -541,7 +541,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ MaxIncomingStreams: maxStreams - 1, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -569,7 +569,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -596,7 +596,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: false, // application rejects 0-RTT - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -622,7 +622,7 @@ var _ = Describe("0-RTT", func() { secondConf := getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -699,7 +699,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -775,7 +775,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index a0854260c..352e0a613 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -2,23 +2,25 @@ package tools import ( "bufio" + "context" "fmt" "io" "log" "os" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/qlog" ) -func NewQlogger(logger io.Writer) logging.Tracer { - return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { +func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { role := "server" if p == logging.PerspectiveClient { role = "client" } - filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) + filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role) fmt.Fprintf(logger, "Creating %s.\n", filename) f, err := os.Create(filename) if err != nil { @@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer { return nil } bw := bufio.NewWriter(f) - return utils.NewBufferedWriteCloser(bw, f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID) + } } diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index ad3b7f7a8..965700c15 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() { serverConfig := &quic.Config{} serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientTracer := &versionNegotiationTracer{} @@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), + maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }}), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) @@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() { expectedVersion := protocol.SupportedVersions[0] // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak + serverTracer := &versionNegotiationTracer{} serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} @@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{ + maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }, }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/versionnegotiation/rtt_test.go b/integrationtests/versionnegotiation/rtt_test.go index 2c8868dbd..ce9cba1be 100644 --- a/integrationtests/versionnegotiation/rtt_test.go +++ b/integrationtests/versionnegotiation/rtt_test.go @@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() { context.Background(), proxy.LocalAddr().String(), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), + maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), ) Expect(err).To(HaveOccurred()) expectDurationInRTTs(startTime, 1) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index 96a98fa8e..a01ac1f8a 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) { RunSpecs(t, "Version Negotiation Suite") } -func maybeAddQlogTracer(c *quic.Config) *quic.Config { +func maybeAddQLOGTracer(c *quic.Config) *quic.Config { if c == nil { c = &quic.Config{} } @@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config { if c.Tracer == nil { c.Tracer = qlogger } else if qlogger != nil { - c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer) + origTracer := c.Tracer + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogger(ctx, p, connID), + origTracer(ctx, p, connID), + ) + } } return c } - -type tracer struct { - logging.NullTracer - createNewConnTracer func() logging.ConnectionTracer -} - -var _ logging.Tracer = &tracer{} - -func newTracer(c func() logging.ConnectionTracer) logging.Tracer { - return &tracer{createNewConnTracer: c} -} - -func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return t.createNewConnTracer() -} diff --git a/interface.go b/interface.go index 4d2a91976..7022cac4e 100644 --- a/interface.go +++ b/interface.go @@ -322,7 +322,7 @@ type Config struct { Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer logging.Tracer + Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer } type ClientHelloInfo struct { diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index f269e7d65..56741ef29 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -5,7 +5,6 @@ package mocklogging import ( - context "context" net "net" reflect "reflect" @@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(logging.ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/interop/client/main.go b/interop/client/main.go index eaec9e67a..747619a09 100644 --- a/interop/client/main.go +++ b/interop/client/main.go @@ -21,7 +21,6 @@ import ( "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" - "github.com/quic-go/quic-go/qlog" ) var errUnsupported = errors.New("unsupported test case") @@ -65,11 +64,7 @@ func runTestcase(testcase string) error { flag.Parse() urls := flag.Args() - getLogWriter, err := utils.GetQLOGWriter() - if err != nil { - return err - } - quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)} + quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer} if testcase == "http3" { r := &http3.RoundTripper{ diff --git a/interop/server/main.go b/interop/server/main.go index 6c607c5c0..df7044624 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -13,7 +13,6 @@ import ( "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" - "github.com/quic-go/quic-go/qlog" ) var tlsConf *tls.Config @@ -38,16 +37,10 @@ func main() { testcase := os.Getenv("TESTCASE") - getLogWriter, err := utils.GetQLOGWriter() - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, Allow0RTT: testcase == "zerortt", - Tracer: qlog.NewTracer(getLogWriter), + Tracer: utils.NewQLOGConnectionTracer, } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { diff --git a/interop/utils/logging.go b/interop/utils/logging.go index 3a6940ad9..30e3f663f 100644 --- a/interop/utils/logging.go +++ b/interop/utils/logging.go @@ -2,14 +2,17 @@ package utils import ( "bufio" + "context" "fmt" "io" "log" "os" "strings" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" + "github.com/quic-go/quic-go/qlog" ) // GetSSLKeyLog creates a file for the TLS key log @@ -25,25 +28,23 @@ func GetSSLKeyLog() (io.WriteCloser, error) { return f, nil } -// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback -func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) { +// NewQLOGConnectionTracer create a qlog file in QLOGDIR +func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { qlogDir := os.Getenv("QLOGDIR") if len(qlogDir) == 0 { - return nil, nil + return nil } if _, err := os.Stat(qlogDir); os.IsNotExist(err) { if err := os.MkdirAll(qlogDir, 0o666); err != nil { - return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error()) + log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err) } } - return func(_ logging.Perspective, connID []byte) io.WriteCloser { - path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID) - f, err := os.Create(path) - if err != nil { - log.Printf("Failed to create qlog file %s: %s", path, err.Error()) - return nil - } - log.Printf("Created qlog file: %s\n", path) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }, nil + path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID) + f, err := os.Create(path) + if err != nil { + log.Printf("Failed to create qlog file %s: %s", path, err.Error()) + return nil + } + log.Printf("Created qlog file: %s\n", path) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) } diff --git a/logging/interface.go b/logging/interface.go index efcef151e..2ce8582ec 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,7 +3,6 @@ package logging import ( - "context" "net" "time" @@ -101,12 +100,6 @@ type ShortHeader struct { // A Tracer traces events. type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - SentPacket(net.Addr, *Header, ByteCount, []Frame) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index 935d899bd..8526cd3a3 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -5,7 +5,6 @@ package logging import ( - context "context" net "net" reflect "reflect" @@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/logging/multiplex.go b/logging/multiplex.go index 8e85db494..672a5cdbd 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer { return &tracerMultiplexer{tracers} } -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer - for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } - } - return NewMultiplexedConnectionTracer(connTracers...) -} - func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { for _, t := range m.tracers { t.SentPacket(remote, hdr, size, frames) diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 9b8515509..d22204999 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -1,7 +1,6 @@ package logging import ( - "context" "errors" "net" "time" @@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() { tracer = NewMultiplexedTracer(tr1, tr2) }) - It("multiplexes the TracerForConnection call", func() { - ctx := context.Background() - connID := protocol.ParseConnectionID([]byte{1, 2, 3}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tracer.TracerForConnection(ctx, PerspectiveClient, connID) - }) - - It("uses multiple connection tracers", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - ctr2 := NewMockConnectionTracer(mockCtrl) - connID := protocol.ParseConnectionID([]byte{1, 2, 3}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) - ctr1.EXPECT().LossTimerCanceled() - ctr2.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("handles tracers that return a nil ConnectionTracer", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) - ctr1.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("returns nil when all tracers return a nil ConnectionTracer", func() { - ctx := context.Background() - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil()) - }) - It("traces the PacketSent event", func() { remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} diff --git a/logging/null_tracer.go b/logging/null_tracer.go index 38052ae3b..de9703857 100644 --- a/logging/null_tracer.go +++ b/logging/null_tracer.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -12,9 +11,6 @@ type NullTracer struct{} var _ Tracer = &NullTracer{} -func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer { - return NullConnectionTracer{} -} func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { } diff --git a/qlog/qlog.go b/qlog/qlog.go index bc2bb233d..4c480e260 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -2,7 +2,6 @@ package qlog import ( "bytes" - "context" "fmt" "io" "log" @@ -49,26 +48,6 @@ func init() { const eventChanSize = 50 -type tracer struct { - logging.NullTracer - - getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser -} - -var _ logging.Tracer = &tracer{} - -// NewTracer creates a new qlog tracer. -func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer { - return &tracer{getLogWriter: getLogWriter} -} - -func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - if w := t.getLogWriter(p, odcid.Bytes()); w != nil { - return NewConnectionTracer(w, p, odcid) - } - return nil -} - type connectionTracer struct { mutex sync.Mutex diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index b4105c051..dc0d2dc18 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -2,7 +2,6 @@ package qlog import ( "bytes" - "context" "encoding/json" "errors" "io" @@ -51,17 +50,6 @@ type entry struct { } var _ = Describe("Tracing", func() { - Context("tracer", func() { - It("returns nil when there's no io.WriteCloser", func() { - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection( - context.Background(), - logging.PerspectiveClient, - protocol.ParseConnectionID([]byte{1, 2, 3, 4}), - )).To(BeNil()) - }) - }) - It("stops writing when encountering an error", func() { buf := &bytes.Buffer{} t := NewConnectionTracer( @@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() { BeforeEach(func() { buf = &bytes.Buffer{} - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection( - context.Background(), + tracer = NewConnectionTracer( + nopWriteCloser(buf), logging.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), ) diff --git a/server.go b/server.go index a38a57848..f5b04549e 100644 --- a/server.go +++ b/server.go @@ -106,6 +106,8 @@ type baseServer struct { connQueue chan quicConn connQueueLen int32 // to be used as an atomic + tracer logging.Tracer + logger utils.Logger } @@ -212,7 +214,16 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear return tr.ListenEarly(tlsConf, config) } -func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { +func newServer( + conn rawConn, + connHandler packetHandlerManager, + connIDGenerator ConnectionIDGenerator, + tlsConf *tls.Config, + config *Config, + tracer logging.Tracer, + onClose func(), + acceptEarly bool, +) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err @@ -229,6 +240,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator C running: make(chan struct{}), receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), newConn: newConnection, + tracer: tracer, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, onClose: onClose, @@ -318,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } } @@ -331,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -345,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -366,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -378,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) return false } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -397,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -416,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false } @@ -430,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false } @@ -440,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false } @@ -468,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() } @@ -504,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") } @@ -585,19 +597,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { - var tracer logging.ConnectionTracer - if s.config.Tracer != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := hdr.DestConnectionID - if origDestConnID.Len() > 0 { - connID = origDestConnID - } - tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), ConnectionTracingKey, tracingID), - protocol.PerspectiveServer, - connID, - ) - } config := s.config if s.config.GetConfigForClient != nil { conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) @@ -607,6 +606,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } config = populateConfig(conf) } + var tracer logging.ConnectionTracer + if config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) + } conn = s.newConn( newSendConn(s.conn, p.remoteAddr, p.info), s.connHandler, @@ -715,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err @@ -729,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } // don't return the error here. Just drop the packet. return nil @@ -738,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { // don't return the error here. Just drop the packet. - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return nil } @@ -792,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err @@ -803,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.config.Tracer != nil { - s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) + if s.tracer != nil { + s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) } if _, err := s.conn.WritePacket(data, remote, oob); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) diff --git a/server_test.go b/server_test.go index e78356c81..bf0adfbb3 100644 --- a/server_test.go +++ b/server_test.go @@ -170,7 +170,7 @@ var _ = Describe("Server", func() { Context("server accepting connections that completed the handshake", func() { var ( - ln *Listener + tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -178,8 +178,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - var err error - ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer}) + tr = &Transport{Conn: conn, Tracer: tracer} + ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) @@ -187,7 +187,7 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - ln.Close() + tr.Close() }) Context("handling packets", func() { @@ -276,7 +276,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, @@ -478,7 +477,6 @@ var _ = Describe("Server", func() { return ok }), ) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) conn := NewMockQUICConn(mockCtrl) serv.newConn = func( @@ -537,7 +535,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }).AnyTimes() - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually @@ -665,7 +662,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }).Times(protocol.MaxAcceptQueueSize) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) @@ -737,7 +733,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -1062,6 +1057,7 @@ var _ = Describe("Server", func() { return ok }) done := make(chan struct{}) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) rejectHdr := parseHeader(b) @@ -1119,7 +1115,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, @@ -1326,6 +1321,7 @@ var _ = Describe("Server", func() { Context("0-RTT", func() { var ( + tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -1333,7 +1329,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer}) + tr = &Transport{Conn: conn, Tracer: tracer} + ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) serv = ln.baseServer @@ -1342,7 +1339,7 @@ var _ = Describe("Server", func() { AfterEach(func() { phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() + tr.Close() }) It("passes packets to existing connections", func() { @@ -1425,7 +1422,6 @@ var _ = Describe("Server", func() { return conn } - tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any()) phm.EXPECT().Get(connID) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) diff --git a/transport.go b/transport.go index baeb592b7..153675da6 100644 --- a/transport.go +++ b/transport.go @@ -95,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf, true); err != nil { + if err := t.init(true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) if err != nil { return nil, err } @@ -124,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf, true); err != nil { + if err := t.init(true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config return nil, err } conf = populateConfig(conf) - if err := t.init(conf, false); err != nil { + if err := t.init(false); err != nil { return nil, err } var onClose func() @@ -157,7 +157,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C return nil, err } conf = populateConfig(conf) - if err := t.init(conf, false); err != nil { + if err := t.init(false); err != nil { return nil, err } var onClose func() @@ -200,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { // only print warnings about the UDP receive buffer size once var receiveBufferWarningOnce sync.Once -func (t *Transport) init(conf *Config, isServer bool) error { +func (t *Transport) init(isServer bool) error { t.initOnce.Do(func() { getMultiplexer().AddConn(t.Conn) @@ -210,7 +210,6 @@ func (t *Transport) init(conf *Config, isServer bool) error { return } - t.Tracer = conf.Tracer t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) diff --git a/transport_test.go b/transport_test.go index 19ef54f20..9441e4dec 100644 --- a/transport_test.go +++ b/transport_test.go @@ -61,7 +61,7 @@ var _ = Describe("Transport", func() { It("handles packets for different packet handlers on the same packet conn", func() { packetChan := make(chan packetToRead) tr := &Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}, true) + tr.init(true) phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) @@ -128,8 +128,9 @@ var _ = Describe("Transport", func() { tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, + Tracer: tracer, } - tr.init(&Config{Tracer: tracer}, true) + tr.init(true) dropped := make(chan struct{}) tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ @@ -148,7 +149,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}, true) + tr.init(true) tr.handlerMap = phm done := make(chan struct{}) @@ -166,7 +167,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}, true) + tr.init(true) tr.handlerMap = phm tempErr := deadlineError{} @@ -188,7 +189,7 @@ var _ = Describe("Transport", func() { Conn: newMockPacketConn(packetChan), ConnectionIDLength: connID.Len(), } - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm @@ -221,7 +222,7 @@ var _ = Describe("Transport", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm @@ -257,7 +258,7 @@ var _ = Describe("Transport", func() { StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, ConnectionIDLength: connID.Len(), } - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm