From 7003450d2bcb0a919609177a0239aab97b709f1e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Sep 2017 16:21:29 +0700 Subject: [PATCH] fix version negotiation --- client.go | 94 +++++++++++++++++++++++------------------- client_test.go | 110 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 140 insertions(+), 64 deletions(-) diff --git a/client.go b/client.go index 021d2d71..b45ab3b8 100644 --- a/client.go +++ b/client.go @@ -17,18 +17,18 @@ import ( ) type client struct { - mutex sync.Mutex - listenErr error + mutex sync.Mutex conn connection hostname string - errorChan chan struct{} handshakeChan <-chan handshakeEvent - tlsConf *tls.Config - config *Config - versionNegotiated bool // has version negotiation completed yet + versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version + versionNegotiated bool // has version negotiation completed yet + + tlsConf *tls.Config + config *Config connectionID protocol.ConnectionID version protocol.VersionNumber @@ -100,23 +100,21 @@ func DialNonFWSecure( clientConfig := populateClientConfig(config) c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: clientConfig.Versions[0], - errorChan: make(chan struct{}), - } - - err = c.createNewSession(nil) - if err != nil { - return nil, err + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + connectionID: connID, + hostname: hostname, + tlsConf: tlsConf, + config: clientConfig, + version: clientConfig.Versions[0], + versionNegotiationChan: make(chan struct{}), } utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - return c.session.(NonFWSession), c.establishSecureConnection() + if err := c.establishSecureConnection(); err != nil { + return nil, err + } + return c.session.(NonFWSession), nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -132,8 +130,7 @@ func Dial( if err != nil { return nil, err } - err = sess.WaitUntilHandshakeComplete() - if err != nil { + if err := sess.WaitUntilHandshakeComplete(); err != nil { return nil, err } return sess, nil @@ -181,11 +178,37 @@ func populateClientConfig(config *Config) *Config { // establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) func (c *client) establishSecureConnection() error { + if err := c.createNewSession(nil); err != nil { + return err + } go c.listen() + var runErr error + errorChan := make(chan struct{}) + go func() { + // session.run() returns as soon as the session is closed + for { + runErr = c.session.run() + if runErr == errCloseSessionForNewVersion { + continue + } + break + } + close(errorChan) + utils.Infof("Connection %x closed.", c.connectionID) + c.conn.Close() + }() + + // wait until the server accepts the QUIC version (or an error occurs) select { - case <-c.errorChan: - return c.listenErr + case <-errorChan: + return runErr + case <-c.versionNegotiationChan: + } + + select { + case <-errorChan: + return runErr case ev := <-c.handshakeChan: if ev.err != nil { return ev.err @@ -263,6 +286,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { // if the server doesn't send a version negotiation packet, it supports the suggested version if !hdr.VersionFlag && !c.versionNegotiated { c.versionNegotiated = true + close(c.versionNegotiationChan) } if hdr.VersionFlag { @@ -298,7 +322,6 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { // switch to negotiated version c.version = newVersion - c.versionNegotiated = true var err error c.connectionID, err = utils.GenerateConnectionID() if err != nil { @@ -306,7 +329,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { } utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID) - c.session.Close(errCloseSessionForNewVersion) + // create a new session and close the old one + // the new session must be created first to update client member variables + oldSession := c.session + defer oldSession.Close(errCloseSessionForNewVersion) return c.createNewSession(hdr.SupportedVersions) } @@ -321,21 +347,5 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.config, negotiatedVersions, ) - if err != nil { - return err - } - - go func() { - // session.run() returns as soon as the session is closed - err := c.session.run() - if err == errCloseSessionForNewVersion { - return - } - c.listenErr = err - close(c.errorChan) - - utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() - }() - return nil + return err } diff --git a/client_test.go b/client_test.go index 0f396482..402adfff 100644 --- a/client_test.go +++ b/client_test.go @@ -45,7 +45,7 @@ var _ = Describe("Client", func() { session: sess, version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, - errorChan: make(chan struct{}), + versionNegotiationChan: make(chan struct{}), } }) @@ -61,9 +61,11 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { + var acceptClientVersionPacket []byte + BeforeEach(func() { newClientSession = func( - _ connection, + conn connection, _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, @@ -71,11 +73,23 @@ var _ = Describe("Client", func() { _ *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { + Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) + // Expect(err).ToNot(HaveOccurred()) return sess, sess.handshakeChan, nil } + // accept the QUIC version suggested by the client + b := &bytes.Buffer{} + err := (&wire.PublicHeader{ + ConnectionID: 0x1337, + PacketNumber: 1, + PacketNumberLen: 1, + }).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + acceptClientVersionPacket = b.Bytes() }) It("dials non-forward-secure", func(done Done) { + packetConn.dataToRead = acceptClientVersionPacket dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -91,10 +105,27 @@ var _ = Describe("Client", func() { }) It("dials a non-forward-secure address", func(done Done) { + serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + server, err := net.ListenUDP("udp", serverAddr) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + go func() { + defer GinkgoRecover() + for { + _, clientAddr, err := server.ReadFromUDP(make([]byte, 200)) + if err != nil { + return + } + _, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr) + Expect(err).ToNot(HaveOccurred()) + } + }() + dialed := make(chan struct{}) go func() { defer GinkgoRecover() - s, err := DialAddrNonFWSecure("localhost:18901", nil, config) + s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) close(dialed) @@ -106,6 +137,7 @@ var _ = Describe("Client", func() { }) It("Dial only returns after the handshake is complete", func(done Done) { + packetConn.dataToRead = acceptClientVersionPacket dialed := make(chan struct{}) go func() { defer GinkgoRecover() @@ -173,12 +205,23 @@ var _ = Describe("Client", func() { close(done) }) - It("returns an error that occurs while waiting for the connection to become secure", func(done Done) { + It("returns an error that occurs during version negotiation", func(done Done) { testErr := errors.New("early handshake error") - var dialErr error go func() { defer GinkgoRecover() - _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + Expect(dialErr).To(MatchError(testErr)) + close(done) + }() + sess.Close(testErr) + }) + + It("returns an error that occurs while waiting for the connection to become secure", func(done Done) { + testErr := errors.New("early handshake error") + packetConn.dataToRead = acceptClientVersionPacket + go func() { + defer GinkgoRecover() + _, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(dialErr).To(MatchError(testErr)) close(done) }() @@ -187,9 +230,9 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the handshake to complete", func(done Done) { testErr := errors.New("late handshake error") - var dialErr error + packetConn.dataToRead = acceptClientVersionPacket go func() { - _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(dialErr).To(MatchError(testErr)) close(done) }() @@ -254,10 +297,20 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) cl.handlePacket(nil, b.Bytes()) Expect(cl.versionNegotiated).To(BeTrue()) + Expect(cl.versionNegotiationChan).To(BeClosed()) }) It("changes the version after receiving a version negotiation packet", func() { var negotiatedVersions []protocol.VersionNumber + newVersion := protocol.VersionNumber(77) + Expect(newVersion).ToNot(Equal(cl.version)) + Expect(config.Versions).To(ContainElement(newVersion)) + packetConn.dataToRead = wire.ComposeVersionNegotiation( + 0x1337, + []protocol.VersionNumber{newVersion}, + ) + sessionChan := make(chan *mockSession) + handshakeChan := make(chan handshakeEvent) newClientSession = func( _ connection, _ string, @@ -268,25 +321,38 @@ var _ = Describe("Client", func() { negotiatedVersionsP []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { negotiatedVersions = negotiatedVersionsP - return &mockSession{ + // make the server accept the new version + if len(negotiatedVersionsP) > 0 { + packetConn.dataToRead = acceptClientVersionPacket + } + sess := &mockSession{ connectionID: connectionID, - }, nil, nil + stopRunLoop: make(chan struct{}), + } + sessionChan <- sess + return sess, handshakeChan, nil } - newVersion := protocol.VersionNumber(77) - Expect(config.Versions).To(ContainElement(newVersion)) - Expect(newVersion).ToNot(Equal(cl.version)) - Expect(sess.packetCount).To(BeZero()) - cl.connectionID = 0x1337 - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) - Expect(cl.version).To(Equal(newVersion)) - Expect(cl.versionNegotiated).To(BeTrue()) - // it swapped the sessions - // Expect(cl.session).ToNot(Equal(sess)) - Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID + established := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cl.establishSecureConnection() + Expect(err).ToNot(HaveOccurred()) + close(established) + }() + var firstSession, secondSession *mockSession + Eventually(sessionChan).Should(Receive(&firstSession)) + Eventually(sessionChan).Should(Receive(&secondSession)) // it didn't pass the version negoation packet to the old session (since it has no payload) - Expect(sess.packetCount).To(BeZero()) + Expect(firstSession.packetCount).To(BeZero()) + Eventually(func() bool { return firstSession.closed }).Should(BeTrue()) + Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion)) + Consistently(func() bool { return secondSession.closed }).Should(BeFalse()) + Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337)) Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion})) + + handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + Eventually(established).Should(BeClosed()) }) It("errors if no matching version is found", func() {