diff --git a/client.go b/client.go index c4fae8a6..4736fcc3 100644 --- a/client.go +++ b/client.go @@ -289,29 +289,16 @@ func (c *client) establishSecureConnection(ctx context.Context) error { } func (c *client) handlePacket(p *receivedPacket) { - if err := c.handlePacketImpl(p); err != nil { - c.logger.Errorf("error handling packet: %s", err) - } -} - -func (c *client) handlePacketImpl(p *receivedPacket) error { - // handle Version Negotiation Packets if p.hdr.IsVersionNegotiation() { - c.mutex.Lock() - err := c.handleVersionNegotiationPacket(p.hdr) - if err != nil { - c.session.destroy(err) - } - c.mutex.Unlock() - // version negotiation packets have no payload - return err + go c.handleVersionNegotiationPacket(p.hdr) + return } if p.hdr.Type == protocol.PacketTypeRetry { c.mutex.Lock() c.handleRetryPacket(p.hdr) c.mutex.Unlock() - return nil + return } // this is the first packet we are receiving @@ -321,29 +308,32 @@ func (c *client) handlePacketImpl(p *receivedPacket) error { } c.session.handlePacket(p) - return nil } -func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { +func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) { + c.mutex.Lock() + defer c.mutex.Unlock() + // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { - c.logger.Debugf("Received a delayed Version Negotiation Packet.") - return nil + c.logger.Debugf("Received a delayed Version Negotiation packet.") + return } for _, v := range hdr.SupportedVersions { if v == c.version { - // the version negotiation packet contains the version that we offered - // this might be a packet sent by an attacker (or by a terribly broken server implementation) - // ignore it - return nil + // The Version Negotiation packet contains the version that we offered. + // This might be a packet sent by an attacker (or by a terribly broken server implementation). + return } } - c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) + c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { - return qerr.InvalidVersion + c.session.destroy(qerr.InvalidVersion) + c.logger.Debugf("No compatible version found.") + return } c.receivedVersionNegotiationPacket = true c.negotiatedVersions = hdr.SupportedVersions @@ -354,7 +344,6 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.session.destroy(errCloseSessionForNewVersion) - return nil } func (c *client) handleRetryPacket(hdr *wire.Header) { diff --git a/client_test.go b/client_test.go index 3d07e2a8..652f4a39 100644 --- a/client_test.go +++ b/client_test.go @@ -679,33 +679,36 @@ var _ = Describe("Client", func() { sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess cl.config = &Config{} - err := cl.handlePacketImpl(&receivedPacket{ + cl.handlePacket(&receivedPacket{ hdr: &wire.Header{ DestConnectionID: connID, SrcConnectionID: connID, Version: cl.version, }, }) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated.Get()).To(BeTrue()) + Eventually(cl.versionNegotiated.Get()).Should(BeTrue()) }) It("errors if no matching version is found", func() { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().destroy(qerr.InvalidVersion) + done := make(chan struct{}) + sess.EXPECT().destroy(qerr.InvalidVersion).Do(func(error) { close(done) }) cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1})) + Eventually(done).Should(BeClosed()) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().destroy(qerr.InvalidVersion) + done := make(chan struct{}) + sess.EXPECT().destroy(qerr.InvalidVersion).Do(func(error) { close(done) }) cl.session = sess v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v})) + Eventually(done).Should(BeClosed()) }) It("changes to the version preferred by the quic.Config", func() { @@ -713,11 +716,15 @@ var _ = Describe("Client", func() { cl.packetHandlers = phm sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().destroy(errCloseSessionForNewVersion) + destroyed := make(chan struct{}) + sess.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { + close(destroyed) + }) cl.session = sess versions := []protocol.VersionNumber{1234, 4321} cl.config = &Config{Versions: versions} cl.handlePacket(composeVersionNegotiationPacket(connID, versions)) + Eventually(destroyed).Should(BeClosed()) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) })