diff --git a/client.go b/client.go index 89e4b0ce..1a90dacf 100644 --- a/client.go +++ b/client.go @@ -34,8 +34,9 @@ type client struct { tlsConf *tls.Config config *Config - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry) initialVersion protocol.VersionNumber version protocol.VersionNumber @@ -390,6 +391,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) { c.logger.Debugf("Ignoring Retry, since a Retry was already received.") return } + c.origDestConnID = c.destConnID c.destConnID = hdr.SrcConnectionID c.token = hdr.Token c.session.destroy(errCloseSessionForRetry) @@ -417,6 +419,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.conn, runner, c.token, + c.origDestConnID, c.destConnID, c.srcConnID, c.config, diff --git a/client_test.go b/client_test.go index ac7f2e49..ae2f4892 100644 --- a/client_test.go +++ b/client_test.go @@ -34,6 +34,7 @@ var _ = Describe("Client", func() { conn connection, runner sessionRunner, token []byte, + origDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -139,6 +140,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -168,6 +170,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, tlsConf *tls.Config, _ *handshake.TransportParameters, @@ -197,6 +200,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -233,6 +237,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -272,6 +277,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -316,6 +322,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -360,6 +367,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -474,6 +482,7 @@ var _ = Describe("Client", func() { tokenP []byte, _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, configP *Config, _ *tls.Config, params *handshake.TransportParameters, @@ -533,7 +542,8 @@ var _ = Describe("Client", func() { conn connection, _ sessionRunner, _ []byte, // token - _ protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, _ *tls.Config, @@ -542,6 +552,14 @@ var _ = Describe("Client", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { + switch len(sessions) { + case 2: // for the first session + Expect(origDestConnID).To(BeNil()) + Expect(destConnID).ToNot(BeNil()) + case 1: // for the second session + Expect(origDestConnID).To(Equal(connID)) + Expect(destConnID).ToNot(Equal(connID)) + } return <-sessions, nil } _, err := Dial(packetConn, addr, "localhost:1337", nil, config) @@ -594,6 +612,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -632,6 +651,7 @@ var _ = Describe("Client", func() { _ []byte, // token _ protocol.ConnectionID, _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, diff --git a/interface.go b/interface.go index 2567144f..5ba909cb 100644 --- a/interface.go +++ b/interface.go @@ -17,7 +17,10 @@ type StreamID = protocol.StreamID type VersionNumber = protocol.VersionNumber // A Cookie can be used to verify the ownership of the client address. -type Cookie = handshake.Cookie +type Cookie struct { + RemoteAddr string + SentTime time.Time +} // ConnectionState records basic details about the QUIC connection. type ConnectionState = handshake.ConnectionState diff --git a/internal/handshake/cookie_generator.go b/internal/handshake/cookie_generator.go index 00f6e7ef..6d1288ed 100644 --- a/internal/handshake/cookie_generator.go +++ b/internal/handshake/cookie_generator.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" ) const ( @@ -14,14 +16,17 @@ const ( // A Cookie is derived from the client address and can be used to verify the ownership of this address. type Cookie struct { - RemoteAddr string - // The time that the STK was issued (resolution 1 second) + RemoteAddr string + OriginalDestConnectionID protocol.ConnectionID + // The time that the Cookie was issued (resolution 1 second) SentTime time.Time } // token is the struct that is used for ASN1 serialization and deserialization type token struct { - Data []byte + RemoteAddr []byte + OriginalDestConnectionID []byte + Timestamp int64 } @@ -42,10 +47,11 @@ func NewCookieGenerator() (*CookieGenerator, error) { } // NewToken generates a new Cookie for a given source address -func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { +func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) { data, err := asn1.Marshal(token{ - Data: encodeRemoteAddr(raddr), - Timestamp: time.Now().Unix(), + RemoteAddr: encodeRemoteAddr(raddr), + OriginalDestConnectionID: origConnID, + Timestamp: time.Now().Unix(), }) if err != nil { return nil, err @@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { if len(rest) != 0 { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } - return &Cookie{ - RemoteAddr: decodeRemoteAddr(t.Data), + cookie := &Cookie{ + RemoteAddr: decodeRemoteAddr(t.RemoteAddr), SentTime: time.Unix(t.Timestamp, 0), - }, nil + } + if len(t.OriginalDestConnectionID) > 0 { + cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + } + return cookie, nil } // encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie diff --git a/internal/handshake/cookie_generator_test.go b/internal/handshake/cookie_generator_test.go index f0480701..6c0a27b0 100644 --- a/internal/handshake/cookie_generator_test.go +++ b/internal/handshake/cookie_generator_test.go @@ -5,6 +5,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -20,7 +22,7 @@ var _ = Describe("Cookie Generator", func() { It("generates a Cookie", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}, nil) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) @@ -33,7 +35,10 @@ var _ = Describe("Cookie Generator", func() { It("accepts a valid cookie", func() { ip := net.IPv4(192, 168, 0, 1) - token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken( + &net.UDPAddr{IP: ip, Port: 1337}, + nil, + ) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) @@ -41,6 +46,18 @@ var _ = Describe("Cookie Generator", func() { // the time resolution of the Cookie is just 1 second // if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) + Expect(cookie.OriginalDestConnectionID).To(BeNil()) + }) + + It("saves the connection ID", func() { + token, err := cookieGen.NewToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ) + Expect(err).ToNot(HaveOccurred()) + cookie, err := cookieGen.DecodeToken(token) + Expect(err).ToNot(HaveOccurred()) + Expect(cookie.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) }) It("rejects invalid tokens", func() { @@ -56,7 +73,7 @@ var _ = Describe("Cookie Generator", func() { }) It("rejects tokens that can be decoded, but have additional payload", func() { - t, err := asn1.Marshal(token{Data: []byte("foobar")}) + t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) Expect(err).ToNot(HaveOccurred()) t = append(t, []byte("rest")...) enc, err := cookieGen.cookieProtector.NewToken(t) @@ -67,7 +84,7 @@ var _ = Describe("Cookie Generator", func() { // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason It("doesn't panic if a tokens has no data", func() { - t, err := asn1.Marshal(token{Data: []byte("")}) + t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) Expect(err).ToNot(HaveOccurred()) enc, err := cookieGen.cookieProtector.NewToken(t) Expect(err).ToNot(HaveOccurred()) @@ -86,7 +103,7 @@ var _ = Describe("Cookie Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - token, err := cookieGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) @@ -99,7 +116,7 @@ var _ = Describe("Cookie Generator", func() { It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := cookieGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index b68e3413..f8ab4c1a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -95,16 +95,11 @@ type cryptoSetup struct { var _ qtls.RecordLayer = &cryptoSetup{} var _ CryptoSetup = &cryptoSetup{} -type versionInfo struct { - initialVersion protocol.VersionNumber - supportedVersions []protocol.VersionNumber - currentVersion protocol.VersionNumber -} - // NewCryptoSetupClient creates a new crypto setup for the client func NewCryptoSetupClient( initialStream io.Writer, handshakeStream io.Writer, + origConnID protocol.ConnectionID, connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), @@ -115,18 +110,22 @@ func NewCryptoSetupClient( logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { + extHandler, receivedTransportParams := newExtensionHandlerClient( + params, + origConnID, + initialVersion, + supportedVersions, + currentVersion, + logger, + ) return newCryptoSetup( initialStream, handshakeStream, connID, - params, + extHandler, + receivedTransportParams, handleParams, tlsConf, - versionInfo{ - currentVersion: currentVersion, - initialVersion: initialVersion, - supportedVersions: supportedVersions, - }, logger, perspective, ) @@ -145,17 +144,20 @@ func NewCryptoSetupServer( logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, error) { + extHandler, receivedTransportParams := newExtensionHandlerServer( + params, + supportedVersions, + currentVersion, + logger, + ) cs, _, err := newCryptoSetup( initialStream, handshakeStream, connID, - params, + extHandler, + receivedTransportParams, handleParams, tlsConf, - versionInfo{ - currentVersion: currentVersion, - supportedVersions: supportedVersions, - }, logger, perspective, ) @@ -166,10 +168,10 @@ func newCryptoSetup( initialStream io.Writer, handshakeStream io.Writer, connID protocol.ConnectionID, - params *TransportParameters, + extHandler tlsExtensionHandler, + transportParamChan <-chan TransportParameters, handleParams func(*TransportParameters), tlsConf *tls.Config, - versionInfo versionInfo, logger utils.Logger, perspective protocol.Perspective, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { @@ -178,40 +180,23 @@ func newCryptoSetup( return nil, nil, err } cs := &cryptoSetup{ - initialStream: initialStream, - initialAEAD: initialAEAD, - handshakeStream: handshakeStream, - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - handleParamsCallback: handleParams, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - handshakeErrChan: make(chan struct{}), - messageErrChan: make(chan error, 1), - clientHelloWrittenChan: make(chan struct{}), - messageChan: make(chan []byte, 100), - receivedReadKey: make(chan struct{}), - receivedWriteKey: make(chan struct{}), - closeChan: make(chan struct{}), - } - var extHandler tlsExtensionHandler - switch perspective { - case protocol.PerspectiveClient: - extHandler, cs.receivedTransportParams = newExtensionHandlerClient( - params, - versionInfo.initialVersion, - versionInfo.supportedVersions, - versionInfo.currentVersion, - logger, - ) - case protocol.PerspectiveServer: - extHandler, cs.receivedTransportParams = newExtensionHandlerServer( - params, - versionInfo.supportedVersions, - versionInfo.currentVersion, - logger, - ) + initialStream: initialStream, + initialAEAD: initialAEAD, + handshakeStream: handshakeStream, + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + handleParamsCallback: handleParams, + receivedTransportParams: transportParamChan, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + handshakeErrChan: make(chan struct{}), + messageErrChan: make(chan error, 1), + clientHelloWrittenChan: make(chan struct{}), + messageChan: make(chan []byte, 100), + receivedReadKey: make(chan struct{}), + receivedWriteKey: make(chan struct{}), + closeChan: make(chan struct{}), } qtlsConf := tlsConfigToQtlsConfig(tlsConf) qtlsConf.AlternativeRecordLayer = cs diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index a2cf5649..bc8e224b 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -199,6 +199,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, @@ -254,6 +255,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, chChan, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, @@ -293,6 +295,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + nil, protocol.ConnectionID{}, cTransportParameters, func(p *TransportParameters) { sTransportParametersRcvd = p }, diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 03d7a4ea..5e4e1c4d 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -2,6 +2,7 @@ package handshake import ( "errors" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -13,6 +14,7 @@ type extensionHandlerClient struct { ourParams *TransportParameters paramsChan chan<- TransportParameters + origConnID protocol.ConnectionID initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber version protocol.VersionNumber @@ -25,6 +27,7 @@ var _ tlsExtensionHandler = &extensionHandlerClient{} // newExtensionHandlerClient creates a new extension handler for the client. func newExtensionHandlerClient( params *TransportParameters, + origConnID protocol.ConnectionID, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, @@ -37,6 +40,7 @@ func newExtensionHandlerClient( return &extensionHandlerClient{ ourParams: params, paramsChan: paramsChan, + origConnID: origConnID, initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, @@ -94,11 +98,16 @@ func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.E } } + params := eetp.Parameters // check that the server sent a stateless reset token - if len(eetp.Parameters.StatelessResetToken) == 0 { + if len(params.StatelessResetToken) == 0 { return errors.New("server didn't sent stateless_reset_token") } - h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters) - h.paramsChan <- eetp.Parameters + // check the Retry token + if !h.origConnID.Equal(params.OriginalConnectionID) { + return fmt.Errorf("expected original_connection_id to equal %s, is %s", h.origConnID, params.OriginalConnectionID) + } + h.logger.Debugf("Received Transport Parameters: %s", ¶ms) + h.paramsChan <- params return nil } diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index d191ec79..66fc150e 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -23,6 +23,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { var h tlsExtensionHandler h, paramsChan = newExtensionHandlerClient( &TransportParameters{}, + nil, version, nil, version, @@ -105,13 +106,28 @@ var _ = Describe("TLS Extension Handler, for the client", func() { Expect(err).To(HaveOccurred()) // this will be some kind of decoding error }) - It("rejects TransportParameters if they don't contain the stateless reset token", func() { + It("errors if the TransportParameters don't contain the stateless reset token", func() { parameters.StatelessResetToken = nil ext := getEncryptedExtensions(parameters) err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("server didn't sent stateless_reset_token")) }) + It("errors if the TransportParameters contain an original_connection_id, although no Retry was performed", func() { + parameters.OriginalConnectionID = protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) + Expect(err).To(MatchError("expected original_connection_id to equal (empty), is 0xdecafbad")) + }) + + It("errors if the TransportParameters contain a wrong original_connection_id", func() { + handler.origConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + parameters.OriginalConnectionID = protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) + Expect(err).To(MatchError("expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad")) + }) + Context("Version Negotiation", func() { It("accepts a valid version negotiation", func() { done := make(chan struct{}) diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index b6f71efc..7d54d7b5 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -22,8 +22,9 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42 * time.Second, + OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, } - Expect(p.String()).To(Equal("&handshake.TransportParameters{InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}")) + Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}")) }) getRandomValue := func() uint64 { @@ -43,6 +44,7 @@ var _ = Describe("Transport Parameters", func() { MaxUniStreams: getRandomValue(), DisableMigration: true, StatelessResetToken: bytes.Repeat([]byte{100}, 16), + OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, } b := &bytes.Buffer{} params.marshal(b) @@ -58,6 +60,7 @@ var _ = Describe("Transport Parameters", func() { Expect(p.IdleTimeout).To(Equal(params.IdleTimeout)) Expect(p.DisableMigration).To(Equal(params.DisableMigration)) Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) + Expect(p.OriginalConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) }) It("errors when the stateless_reset_token has the wrong length", func() { @@ -167,4 +170,14 @@ var _ = Describe("Transport Parameters", func() { p := &TransportParameters{} Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("client sent a stateless_reset_token")) }) + + It("errors if the client sent a stateless_reset_token", func() { + params := &TransportParameters{ + OriginalConnectionID: protocol.ConnectionID{0xca, 0xfe}, + } + b := &bytes.Buffer{} + params.marshal(b) + p := &TransportParameters{} + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("client sent an original_connection_id")) + }) }) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index c6fb6355..500be008 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -15,6 +15,7 @@ import ( type transportParameterID uint16 const ( + originalConnectionIDParameterID transportParameterID = 0x0 idleTimeoutParameterID transportParameterID = 0x1 statelessResetTokenParameterID transportParameterID = 0x2 maxPacketSizeParameterID transportParameterID = 0x3 @@ -39,9 +40,11 @@ type TransportParameters struct { MaxUniStreams uint64 MaxBidiStreams uint64 - IdleTimeout time.Duration - DisableMigration bool - StatelessResetToken []byte + IdleTimeout time.Duration + DisableMigration bool + + StatelessResetToken []byte + OriginalConnectionID protocol.ConnectionID } func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error { @@ -86,6 +89,11 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective b := make([]byte, 16) r.Read(b) p.StatelessResetToken = b + case originalConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent an original_connection_id") + } + p.OriginalConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) default: r.Seek(int64(paramLen), io.SeekCurrent) } @@ -188,9 +196,15 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes b.Write(p.StatelessResetToken) } + // original_connection_id + if p.OriginalConnectionID.Len() > 0 { + utils.BigEndian.WriteUint16(b, uint16(originalConnectionIDParameterID)) + utils.BigEndian.WriteUint16(b, uint16(p.OriginalConnectionID.Len())) + b.Write(p.OriginalConnectionID.Bytes()) + } } // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { - return fmt.Sprintf("&handshake.TransportParameters{InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) + return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) } diff --git a/server.go b/server.go index 38b4a392..6a4f8b94 100644 --- a/server.go +++ b/server.go @@ -340,11 +340,16 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con return nil, nil, errors.New("dropping too small Initial packet") } - var cookie *handshake.Cookie + var cookie *Cookie + var origDestConnectionID protocol.ConnectionID if len(hdr.Token) > 0 { c, err := s.cookieGenerator.DecodeToken(hdr.Token) if err == nil { - cookie = c + cookie = &Cookie{ + RemoteAddr: c.RemoteAddr, + SentTime: c.SentTime, + } + origDestConnectionID = c.OriginalDestConnectionID } } if !s.config.AcceptCookie(p.remoteAddr, cookie) { @@ -359,7 +364,14 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con return nil, nil, err } s.logger.Debugf("Changing connection ID to %s.", connID) - sess, err := s.createNewSession(p.remoteAddr, hdr.DestConnectionID, hdr.SrcConnectionID, connID, hdr.Version) + sess, err := s.createNewSession( + p.remoteAddr, + origDestConnectionID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + hdr.Version, + ) if err != nil { return nil, nil, err } @@ -369,7 +381,8 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con func (s *server) createNewSession( remoteAddr net.Addr, - origConnID protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, version protocol.VersionNumber, @@ -384,12 +397,13 @@ func (s *server) createNewSession( MaxUniStreams: uint64(s.config.MaxIncomingUniStreams), DisableMigration: true, // TODO(#855): generate a real token - StatelessResetToken: bytes.Repeat([]byte{42}, 16), + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + OriginalConnectionID: origDestConnID, } sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionRunner, - origConnID, + clientDestConnID, destConnID, srcConnID, s.config, @@ -406,7 +420,7 @@ func (s *server) createNewSession( } func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { - token, err := s.cookieGenerator.NewToken(remoteAddr) + token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID) if err != nil { return err } diff --git a/server_test.go b/server_test.go index 7d57be38..d63ba735 100644 --- a/server_test.go +++ b/server_test.go @@ -165,7 +165,7 @@ var _ = Describe("Server", func() { close(done) return false } - token, err := serv.cookieGenerator.NewToken(raddr) + token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) serv.handlePacket(&receivedPacket{ remoteAddr: raddr, @@ -222,7 +222,7 @@ var _ = Describe("Server", func() { }) It("replies with a Retry packet, if a Cookie is required", func() { - serv.config.AcceptCookie = func(_ net.Addr, _ *handshake.Cookie) bool { return false } + serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, @@ -244,7 +244,7 @@ var _ = Describe("Server", func() { }) It("creates a session, if no Cookie is required", func() { - serv.config.AcceptCookie = func(_ net.Addr, _ *handshake.Cookie) bool { return true } + serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, @@ -358,7 +358,7 @@ var _ = Describe("Server", func() { sess.EXPECT().run().Do(func() {}) return sess, nil } - _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, protocol.VersionWhatever) + _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) close(completeHandshake) diff --git a/session.go b/session.go index 0921af3d..2bf78e60 100644 --- a/session.go +++ b/session.go @@ -135,7 +135,7 @@ var _ streamSender = &session{} var newSession = func( conn connection, runner sessionRunner, - origConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -170,7 +170,7 @@ var newSession = func( cs, err := handshake.NewCryptoSetupServer( initialStream, handshakeStream, - origConnID, + clientDestConnID, params, s.processTransportParameters, tlsConf, @@ -212,6 +212,7 @@ var newClientSession = func( conn connection, runner sessionRunner, token []byte, + origDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -238,6 +239,7 @@ var newClientSession = func( cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( initialStream, handshakeStream, + origDestConnID, s.destConnID, params, s.processTransportParameters, diff --git a/session_test.go b/session_test.go index 331a8763..54b94c8d 100644 --- a/session_test.go +++ b/session_test.go @@ -1291,6 +1291,7 @@ var _ = Describe("Client Session", func() { []byte("token"), protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, populateClientConfig(&Config{}, true), nil, // tls.Config nil, // transport parameters