From 84f3ec53438ace543dd51a6de040bdd702962014 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 20 Sep 2017 09:27:37 +0700 Subject: [PATCH] reject packets with the wrong connection ID in the client --- client.go | 8 ++++++- client_test.go | 58 ++++++++++++++++++++++++++++++++++---------------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 6c44cce30..7818c2af1 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,8 @@ type client struct { } var ( + // make it possible to mock connection ID generation in the tests + generateConnectionID = utils.GenerateConnectionID errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) @@ -82,7 +84,7 @@ func DialNonFWSecure( tlsConf *tls.Config, config *Config, ) (NonFWSession, error) { - connID, err := utils.GenerateConnectionID() + connID, err := generateConnectionID() if err != nil { return nil, err } @@ -257,6 +259,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { if hdr.TruncateConnectionID && !c.config.RequestConnectionIDTruncation { return } + // reject packets with the wrong connection ID + if !hdr.TruncateConnectionID && hdr.ConnectionID != c.connectionID { + return + } hdr.Raw = packet[:len(packet)-r.Len()] c.mutex.Lock() diff --git a/client_test.go b/client_test.go index 8403101af..3b5dd13fa 100644 --- a/client_test.go +++ b/client_test.go @@ -27,6 +27,18 @@ var _ = Describe("Client", func() { originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) ) + // generate a packet sent by the server that accepts the QUIC version suggested by the client + acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte { + b := &bytes.Buffer{} + err := (&wire.PublicHeader{ + ConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: 1, + }).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + return b.Bytes() + } + BeforeEach(func() { originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) @@ -62,7 +74,7 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - var acceptClientVersionPacket []byte + var origGenerateConnectionID func() (protocol.ConnectionID, error) BeforeEach(func() { newClientSession = func( @@ -75,22 +87,20 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) - // Expect(err).ToNot(HaveOccurred()) return sess, sess.handshakeChan, nil } - // accept the QUIC version suggested by the client - b := &bytes.Buffer{} - err := (&wire.PublicHeader{ - ConnectionID: 0x1337, - PacketNumber: 1, - PacketNumberLen: 1, - }).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - acceptClientVersionPacket = b.Bytes() + origGenerateConnectionID = generateConnectionID + generateConnectionID = func() (protocol.ConnectionID, error) { + return cl.connectionID, nil + } + }) + + AfterEach(func() { + generateConnectionID = origGenerateConnectionID }) It("dials non-forward-secure", func(done Done) { - packetConn.dataToRead = acceptClientVersionPacket + packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -118,7 +128,7 @@ var _ = Describe("Client", func() { if err != nil { return } - _, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr) + _, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr) Expect(err).ToNot(HaveOccurred()) } }() @@ -138,7 +148,7 @@ var _ = Describe("Client", func() { }) It("Dial only returns after the handshake is complete", func(done Done) { - packetConn.dataToRead = acceptClientVersionPacket + packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -219,7 +229,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func(done Done) { testErr := errors.New("early handshake error") - packetConn.dataToRead = acceptClientVersionPacket + packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) go func() { defer GinkgoRecover() _, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) @@ -231,7 +241,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the handshake to complete", func(done Done) { testErr := errors.New("late handshake error") - packetConn.dataToRead = acceptClientVersionPacket + packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) go func() { _, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(dialErr).To(MatchError(testErr)) @@ -307,7 +317,7 @@ var _ = Describe("Client", func() { Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) packetConn.dataToRead = wire.ComposeVersionNegotiation( - 0x1337, + cl.connectionID, []protocol.VersionNumber{newVersion}, ) sessionChan := make(chan *mockSession) @@ -324,7 +334,7 @@ var _ = Describe("Client", func() { negotiatedVersions = negotiatedVersionsP // make the server accept the new version if len(negotiatedVersionsP) > 0 { - packetConn.dataToRead = acceptClientVersionPacket + packetConn.dataToRead = acceptClientVersionPacket(connectionID) } sess := &mockSession{ connectionID: connectionID, @@ -440,6 +450,18 @@ var _ = Describe("Client", func() { Expect(sess.closed).To(BeFalse()) }) + It("ignores packets with the wrong connection ID", func() { + buf := &bytes.Buffer{} + (&wire.PublicHeader{ + ConnectionID: cl.connectionID + 1, + PacketNumber: 1, + PacketNumberLen: 1, + }).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer) + cl.handlePacket(addr, buf.Bytes()) + Expect(sess.packetCount).To(BeZero()) + Expect(sess.closed).To(BeFalse()) + }) + It("creates new sessions with the right parameters", func(done Done) { c := make(chan struct{}) var cconn connection