From 70cb67bd2a6ec7316e9be245c535a6201a0d4457 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 11 Nov 2018 11:00:40 +0700 Subject: [PATCH] check that the server sends the correct original_connection_id --- client.go | 7 ++++-- client_test.go | 22 ++++++++++++++++++- internal/handshake/crypto_setup.go | 2 ++ internal/handshake/crypto_setup_test.go | 3 +++ .../handshake/tls_extension_handler_client.go | 15 ++++++++++--- .../tls_extension_handler_client_test.go | 18 ++++++++++++++- session.go | 2 ++ session_test.go | 1 + 8 files changed, 63 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 70e490b0..5f7b3c1a 100644 --- a/client.go +++ b/client.go @@ -34,8 +34,9 @@ type client struct { tlsConf *tls.Config config *Config - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry) initialVersion protocol.VersionNumber version protocol.VersionNumber @@ -390,6 +391,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) { c.logger.Debugf("Ignoring Retry, since a Retry was already received.") return } + c.origDestConnID = c.destConnID c.destConnID = hdr.SrcConnectionID c.token = hdr.Token c.session.destroy(errCloseSessionForRetry) @@ -417,6 +419,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.conn, runner, c.token, + c.origDestConnID, c.destConnID, c.srcConnID, c.config, diff --git a/client_test.go b/client_test.go index 33b58144..afda566a 100644 --- a/client_test.go +++ b/client_test.go @@ -34,6 +34,7 @@ var _ = Describe("Client", func() { conn connection, runner sessionRunner, token []byte, + origDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -138,6 +139,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -167,6 +169,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, tlsConf *tls.Config, _ *handshake.TransportParameters, @@ -196,6 +199,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -232,6 +236,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -271,6 +276,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -315,6 +321,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -355,6 +362,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -469,6 +477,7 @@ var _ = Describe("Client", func() { tokenP []byte, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, configP *Config, _ *tls.Config, params *handshake.TransportParameters, @@ -528,7 +537,8 @@ var _ = Describe("Client", func() { conn connection, _ sessionRunner, _ []byte, // token - _ protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, _ *tls.Config, @@ -537,6 +547,14 @@ var _ = Describe("Client", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { + switch len(sessions) { + case 2: // for the first session + Expect(origDestConnID).To(BeNil()) + Expect(destConnID).ToNot(BeNil()) + case 1: // for the second session + Expect(origDestConnID).To(Equal(connID)) + Expect(destConnID).ToNot(Equal(connID)) + } return <-sessions, nil } _, err := Dial(packetConn, addr, "localhost:1337", nil, config) @@ -589,6 +607,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -627,6 +646,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index f9e44a67..f8ab4c1a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -99,6 +99,7 @@ var _ CryptoSetup = &cryptoSetup{} func NewCryptoSetupClient( initialStream io.Writer, handshakeStream io.Writer, + origConnID protocol.ConnectionID, connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), @@ -111,6 +112,7 @@ func NewCryptoSetupClient( ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { extHandler, receivedTransportParams := newExtensionHandlerClient( params, + origConnID, initialVersion, supportedVersions, currentVersion, diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index a2cf5649..bc8e224b 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -199,6 +199,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, @@ -254,6 +255,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, chChan, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, @@ -293,6 +295,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, cTransportParameters, func(p *TransportParameters) { sTransportParametersRcvd = p }, diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 182b256a..71dfa949 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -2,6 +2,7 @@ package handshake import ( "errors" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -13,6 +14,7 @@ type extensionHandlerClient struct { ourParams *TransportParameters paramsChan chan<- TransportParameters + origConnID protocol.ConnectionID initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber version protocol.VersionNumber @@ -25,6 +27,7 @@ var _ tlsExtensionHandler = &extensionHandlerClient{} // newExtensionHandlerClient creates a new extension handler for the client. func newExtensionHandlerClient( params *TransportParameters, + origConnID protocol.ConnectionID, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, @@ -37,6 +40,7 @@ func newExtensionHandlerClient( return &extensionHandlerClient{ ourParams: params, paramsChan: paramsChan, + origConnID: origConnID, initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, @@ -94,11 +98,16 @@ func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.E } } + params := eetp.Parameters // check that the server sent a stateless reset token - if len(eetp.Parameters.StatelessResetToken) == 0 { + if len(params.StatelessResetToken) == 0 { return errors.New("server didn't sent stateless_reset_token") } - h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters) - h.paramsChan <- eetp.Parameters + // check the Retry token + if !h.origConnID.Equal(params.OriginalConnectionID) { + return fmt.Errorf("expected original_connection_id to equal %s, is %s", h.origConnID, params.OriginalConnectionID) + } + h.logger.Debugf("Received Transport Parameters: %s", ¶ms) + h.paramsChan <- params return nil } diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index d191ec79..66fc150e 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -23,6 +23,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { var h tlsExtensionHandler h, paramsChan = newExtensionHandlerClient( &TransportParameters{}, + nil, version, nil, version, @@ -105,13 +106,28 @@ var _ = Describe("TLS Extension Handler, for the client", func() { Expect(err).To(HaveOccurred()) // this will be some kind of decoding error }) - It("rejects TransportParameters if they don't contain the stateless reset token", func() { + It("errors if the TransportParameters don't contain the stateless reset token", func() { parameters.StatelessResetToken = nil ext := getEncryptedExtensions(parameters) err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("server didn't sent stateless_reset_token")) }) + It("errors if the TransportParameters contain an original_connection_id, although no Retry was performed", func() { + parameters.OriginalConnectionID = protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) + Expect(err).To(MatchError("expected original_connection_id to equal (empty), is 0xdecafbad")) + }) + + It("errors if the TransportParameters contain a wrong original_connection_id", func() { + handler.origConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + parameters.OriginalConnectionID = protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) + Expect(err).To(MatchError("expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad")) + }) + Context("Version Negotiation", func() { It("accepts a valid version negotiation", func() { done := make(chan struct{}) diff --git a/session.go b/session.go index cea7cc9d..8842a4b9 100644 --- a/session.go +++ b/session.go @@ -205,6 +205,7 @@ var newClientSession = func( conn connection, runner sessionRunner, token []byte, + origDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -231,6 +232,7 @@ var newClientSession = func( cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( initialStream, handshakeStream, + origDestConnID, s.destConnID, params, s.processTransportParameters, diff --git a/session_test.go b/session_test.go index 14f8e704..607cbb12 100644 --- a/session_test.go +++ b/session_test.go @@ -1288,6 +1288,7 @@ var _ = Describe("Client Session", func() { []byte("token"), protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, populateClientConfig(&Config{}, true), nil, // tls.Config nil, // transport parameters