From a2988bce8dcd141df4d0e765aeedb57bb43ba30d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 2 Oct 2017 12:10:44 +0700 Subject: [PATCH] send versions in ClientHello and EncryptedExtensions TLS parameters It only sends the right values now, but doesn't yet perform any validation. --- client.go | 8 +++++--- client_test.go | 15 +++++++++++++-- internal/handshake/crypto_setup_tls.go | 6 ++++-- internal/handshake/crypto_setup_tls_test.go | 1 + internal/handshake/tl_extension_handler_server.go | 15 ++++++++++++--- .../handshake/tls_extension_handler_client.go | 15 +++++++++++---- .../tls_extension_handler_client_test.go | 7 ++++++- .../tls_extension_handler_server_test.go | 9 ++++++--- session.go | 10 +++++++--- session_test.go | 1 + 10 files changed, 66 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index c6c8b2b4c..f74b9ab5c 100644 --- a/client.go +++ b/client.go @@ -181,7 +181,7 @@ func populateClientConfig(config *Config) *Config { // establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) func (c *client) establishSecureConnection() error { - if err := c.createNewSession(nil); err != nil { + if err := c.createNewSession(c.version, nil); err != nil { return err } go c.listen() @@ -332,6 +332,7 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { } // switch to negotiated version + initialVersion := c.version c.version = newVersion var err error c.connectionID, err = utils.GenerateConnectionID() @@ -344,10 +345,10 @@ func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { // the new session must be created first to update client member variables oldSession := c.session defer oldSession.Close(errCloseSessionForNewVersion) - return c.createNewSession(hdr.SupportedVersions) + return c.createNewSession(initialVersion, hdr.SupportedVersions) } -func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { +func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { var err error c.session, c.handshakeChan, err = newClientSession( c.conn, @@ -356,6 +357,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.connectionID, c.tlsConf, c.config, + initialVersion, negotiatedVersions, ) return err diff --git a/client_test.go b/client_test.go index 0ba3bb149..9bf60d21a 100644 --- a/client_test.go +++ b/client_test.go @@ -24,7 +24,7 @@ var _ = Describe("Client", func() { packetConn *mockPacketConn addr net.Addr - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -84,6 +84,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *tls.Config, _ *Config, + _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) @@ -173,6 +174,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *tls.Config, _ *Config, + _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { remoteAddrChan <- conn.RemoteAddr().String() @@ -199,6 +201,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *tls.Config, _ *Config, + _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { hostnameChan <- h @@ -288,6 +291,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *tls.Config, _ *Config, + _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { return nil, nil, testErr @@ -312,6 +316,7 @@ var _ = Describe("Client", func() { }) It("changes the version after receiving a version negotiation packet", func() { + var initialVersion protocol.VersionNumber var negotiatedVersions []protocol.VersionNumber newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) @@ -329,8 +334,10 @@ var _ = Describe("Client", func() { connectionID protocol.ConnectionID, _ *tls.Config, _ *Config, + initialVersionP protocol.VersionNumber, negotiatedVersionsP []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { + initialVersion = initialVersionP negotiatedVersions = negotiatedVersionsP // make the server accept the new version if len(negotiatedVersionsP) > 0 { @@ -351,6 +358,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) close(established) }() + actualInitialVersion := cl.version var firstSession, secondSession *mockSession Eventually(sessionChan).Should(Receive(&firstSession)) Eventually(sessionChan).Should(Receive(&secondSession)) @@ -361,6 +369,7 @@ var _ = Describe("Client", func() { Consistently(func() bool { return secondSession.closed }).Should(BeFalse()) Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337)) Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion})) + Expect(initialVersion).To(Equal(actualInitialVersion)) handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} Eventually(established).Should(BeClosed()) @@ -375,7 +384,8 @@ var _ = Describe("Client", func() { connectionID protocol.ConnectionID, _ *tls.Config, _ *Config, - negotiatedVersionsP []protocol.VersionNumber, + _ protocol.VersionNumber, + _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { atomic.AddUint32(&sessionCounter, 1) return sess, nil, nil @@ -475,6 +485,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *tls.Config, configP *Config, + _ protocol.VersionNumber, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { cconn = connP diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 64478865d..ae3317ada 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -40,6 +40,7 @@ func NewCryptoSetupTLSServer( tlsConfig *tls.Config, transportParams *TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, + supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, ) (CryptoSetup, ParamsNegotiator, error) { mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) @@ -54,7 +55,7 @@ func NewCryptoSetupTLSServer( nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerServer(params), + extensionHandler: newExtensionHandlerServer(params, supportedVersions), }, params, nil } @@ -64,6 +65,7 @@ func NewCryptoSetupTLSClient( tlsConfig *tls.Config, transportParams *TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, + initialVersion protocol.VersionNumber, version protocol.VersionNumber, ) (CryptoSetup, ParamsNegotiator, error) { mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) @@ -79,7 +81,7 @@ func NewCryptoSetupTLSClient( nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerClient(params), + extensionHandler: newExtensionHandlerClient(params, initialVersion, version), }, params, nil } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 2524bb5a6..583fdc6eb 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -44,6 +44,7 @@ var _ = Describe("TLS Crypto Setup", func() { testdata.GetTLSConfig(), &TransportParameters{}, aeadChanged, + nil, protocol.VersionTLS, ) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/tl_extension_handler_server.go b/internal/handshake/tl_extension_handler_server.go index ee35f9499..3687baaac 100644 --- a/internal/handshake/tl_extension_handler_server.go +++ b/internal/handshake/tl_extension_handler_server.go @@ -12,12 +12,17 @@ import ( type extensionHandlerServer struct { params *paramsNegotiator + + supportedVersions []protocol.VersionNumber } var _ mint.AppExtensionHandler = &extensionHandlerServer{} -func newExtensionHandlerServer(params *paramsNegotiator) *extensionHandlerServer { - return &extensionHandlerServer{params: params} +func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber) *extensionHandlerServer { + return &extensionHandlerServer{ + params: params, + supportedVersions: supportedVersions, + } } func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { @@ -29,8 +34,12 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi h.params.GetTransportParameters(), transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, ) + supportedVersions := make([]uint32, len(h.supportedVersions)) + for i, v := range h.supportedVersions { + supportedVersions[i] = uint32(v) + } data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ - SupportedVersions: []uint32{uint32(protocol.VersionTLS)}, + SupportedVersions: supportedVersions, Parameters: transportParams, }) if err != nil { diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index bd0d5ce3c..3a6a07523 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -11,12 +11,19 @@ import ( type extensionHandlerClient struct { params *paramsNegotiator + + initialVersion protocol.VersionNumber + version protocol.VersionNumber } var _ mint.AppExtensionHandler = &extensionHandlerClient{} -func newExtensionHandlerClient(params *paramsNegotiator) *extensionHandlerClient { - return &extensionHandlerClient{params: params} +func newExtensionHandlerClient(params *paramsNegotiator, initialVersion, version protocol.VersionNumber) *extensionHandlerClient { + return &extensionHandlerClient{ + params: params, + initialVersion: initialVersion, + version: version, + } } func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { @@ -25,8 +32,8 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi } data, err := syntax.Marshal(clientHelloTransportParameters{ - NegotiatedVersion: uint32(protocol.VersionTLS), - InitialVersion: uint32(protocol.VersionTLS), + NegotiatedVersion: uint32(h.version), + InitialVersion: uint32(h.initialVersion), Parameters: h.params.GetTransportParameters(), }) if err != nil { diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index 25a18807d..ffd3d8a72 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -6,6 +6,7 @@ import ( "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -16,7 +17,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { BeforeEach(func() { pn := ¶msNegotiator{} - handler = newExtensionHandlerClient(pn) + handler = newExtensionHandlerClient(pn, protocol.VersionWhatever, protocol.VersionWhatever) el = make(mint.ExtensionList, 0) }) @@ -32,6 +33,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() { }) It("adds TransportParameters to the ClientHello", func() { + handler.initialVersion = 13 + handler.version = 37 err := handler.Send(mint.HandshakeTypeClientHello, &el) Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) @@ -41,6 +44,8 @@ var _ = Describe("TLS Extension Handler, for the client", func() { chtp := &clientHelloTransportParameters{} _, err = syntax.Unmarshal(ext.data, chtp) Expect(err).ToNot(HaveOccurred()) + Expect(chtp.InitialVersion).To(BeEquivalentTo(13)) + Expect(chtp.NegotiatedVersion).To(BeEquivalentTo(37)) }) }) diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 3395b67e2..c8bc463ea 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -5,6 +5,7 @@ import ( "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -15,7 +16,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { BeforeEach(func() { pn := ¶msNegotiator{} - handler = newExtensionHandlerServer(pn) + handler = newExtensionHandlerServer(pn, nil) el = make(mint.ExtensionList, 0) }) @@ -31,15 +32,17 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }) It("adds TransportParameters to the EncryptedExtensions message", func() { + handler.supportedVersions = []protocol.VersionNumber{13, 37, 42} err := handler.Send(mint.HandshakeTypeEncryptedExtensions, &el) Expect(err).ToNot(HaveOccurred()) Expect(el).To(HaveLen(1)) ext := &tlsExtensionBody{} found := el.Find(ext) Expect(found).To(BeTrue()) - chtp := &encryptedExtensionsTransportParameters{} - _, err = syntax.Unmarshal(ext.data, chtp) + eetp := &encryptedExtensionsTransportParameters{} + _, err = syntax.Unmarshal(ext.data, eetp) Expect(err).ToNot(HaveOccurred()) + Expect(eetp.SupportedVersions).To(Equal([]uint32{13, 37, 42})) }) }) diff --git a/session.go b/session.go index 2c46d3180..c5b9e4e9a 100644 --- a/session.go +++ b/session.go @@ -134,7 +134,7 @@ func newSession( version: v, config: config, } - return s.setup(sCfg, "", tlsConf, nil) + return s.setup(sCfg, "", tlsConf, v, nil) } // declare this as a variable, such that we can it mock it in the tests @@ -145,7 +145,8 @@ var newClientSession = func( connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, - negotiatedVersions []protocol.VersionNumber, + initialVersion protocol.VersionNumber, // needed for validation of the version negotaion over TLS + negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton ) (packetHandler, <-chan handshakeEvent, error) { s := &session{ conn: conn, @@ -154,13 +155,14 @@ var newClientSession = func( version: v, config: config, } - return s.setup(nil, hostname, tlsConf, negotiatedVersions) + return s.setup(nil, hostname, tlsConf, v, negotiatedVersions) } func (s *session) setup( scfg *handshake.ServerConfig, hostname string, tlsConf *tls.Config, + initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { aeadChanged := make(chan protocol.EncryptionLevel, 2) @@ -198,6 +200,7 @@ func (s *session) setup( tlsConf, transportParams, aeadChanged, + s.config.Versions, s.version, ) } else { @@ -219,6 +222,7 @@ func (s *session) setup( tlsConf, transportParams, aeadChanged, + initialVersion, s.version, ) } else { diff --git a/session_test.go b/session_test.go index f02b3cdad..39be20b92 100644 --- a/session_test.go +++ b/session_test.go @@ -1743,6 +1743,7 @@ var _ = Describe("Client Session", func() { 0, nil, populateClientConfig(&Config{}), + protocol.VersionWhatever, nil, ) sess = sessP.(*session)