From 416298577d89b17911f88f0c257a4d71ec9188d6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2017 17:27:04 +0700 Subject: [PATCH] only accept one version negotiation packet --- client.go | 17 +++++++++-------- client_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index b45ab3b8d..8770fe0f9 100644 --- a/client.go +++ b/client.go @@ -24,8 +24,9 @@ type client struct { handshakeChan <-chan handshakeEvent - versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version - versionNegotiated bool // has version negotiation completed yet + versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version + versionNegotiated bool // has version negotiation completed yet + receivedVersionNegotiationPacket bool tlsConf *tls.Config config *Config @@ -187,12 +188,10 @@ func (c *client) establishSecureConnection() error { errorChan := make(chan struct{}) go func() { // session.run() returns as soon as the session is closed - for { + runErr = c.session.run() + if runErr == errCloseSessionForNewVersion { + // run the new session runErr = c.session.run() - if runErr == errCloseSessionForNewVersion { - continue - } - break } close(errorChan) utils.Infof("Connection %x closed.", c.connectionID) @@ -278,7 +277,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { } // ignore delayed / duplicated version negotiation packets - if c.versionNegotiated && hdr.VersionFlag { + if (c.receivedVersionNegotiationPacket || c.versionNegotiated) && hdr.VersionFlag { return } @@ -315,6 +314,8 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { } } + c.receivedVersionNegotiationPacket = true + newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if newVersion == protocol.VersionUnsupported { return qerr.InvalidVersion diff --git a/client_test.go b/client_test.go index 402adfff5..f8fed2890 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "net" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -355,6 +356,34 @@ var _ = Describe("Client", func() { Eventually(established).Should(BeClosed()) }) + It("only accepts one version negotiation packet", func() { + sessionCounter := uint32(0) + newClientSession = func( + _ connection, + _ string, + _ protocol.VersionNumber, + connectionID protocol.ConnectionID, + _ *tls.Config, + _ *Config, + negotiatedVersionsP []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + atomic.AddUint32(&sessionCounter, 1) + return sess, nil, nil + } + go cl.establishSecureConnection() + Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) + newVersion := protocol.VersionNumber(77) + Expect(newVersion).ToNot(Equal(cl.version)) + Expect(config.Versions).To(ContainElement(newVersion)) + cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) + newVersion = protocol.VersionNumber(78) + Expect(newVersion).ToNot(Equal(cl.version)) + Expect(config.Versions).To(ContainElement(newVersion)) + cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) + }) + It("errors if no matching version is found", func() { cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) Expect(cl.session.(*mockSession).closed).To(BeTrue())