diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index a5dab79d..f787a9bd 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -87,8 +87,9 @@ type cryptoSetup struct { handshakeOpener Opener handshakeSealer Sealer - opener Opener - sealer Sealer + oneRTTStream io.Writer + opener Opener + sealer Sealer // TODO: add a 1-RTT stream (used for session tickets) receivedWriteKey chan struct{} @@ -106,6 +107,7 @@ var _ CryptoSetup = &cryptoSetup{} func NewCryptoSetupClient( initialStream io.Writer, handshakeStream io.Writer, + oneRTTStream io.Writer, connID protocol.ConnectionID, chtp *ClientHelloTransportParameters, handleParams func([]byte), @@ -115,6 +117,7 @@ func NewCryptoSetupClient( cs, clientHelloWritten, err := newCryptoSetup( initialStream, handshakeStream, + oneRTTStream, connID, chtp.Marshal(), handleParams, @@ -133,6 +136,7 @@ func NewCryptoSetupClient( func NewCryptoSetupServer( initialStream io.Writer, handshakeStream io.Writer, + oneRTTStream io.Writer, connID protocol.ConnectionID, eetp *EncryptedExtensionsTransportParameters, handleParams func([]byte), @@ -142,6 +146,7 @@ func NewCryptoSetupServer( cs, _, err := newCryptoSetup( initialStream, handshakeStream, + oneRTTStream, connID, eetp.Marshal(), handleParams, @@ -159,6 +164,7 @@ func NewCryptoSetupServer( func newCryptoSetup( initialStream io.Writer, handshakeStream io.Writer, + oneRTTStream io.Writer, connID protocol.ConnectionID, paramBytes []byte, // the marshaled transport parameters handleParams func([]byte), @@ -176,6 +182,7 @@ func newCryptoSetup( initialSealer: initialSealer, initialOpener: initialOpener, handshakeStream: handshakeStream, + oneRTTStream: oneRTTStream, readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, handleParamsCallback: handleParams, @@ -458,8 +465,10 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { return n, err case protocol.EncryptionHandshake: return h.handshakeStream.Write(p) + case protocol.Encryption1RTT: + return h.oneRTTStream.Write(p) default: - return 0, fmt.Errorf("unexpected write encryption level: %s", h.writeEncLevel) + panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) } } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 93a57045..6a562f54 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "io/ioutil" "math/big" "time" @@ -69,6 +70,7 @@ var _ = Describe("Crypto Setup TLS", func() { server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &EncryptedExtensionsTransportParameters{ NegotiatedVersion: protocol.VersionTLS, @@ -99,6 +101,7 @@ var _ = Describe("Crypto Setup TLS", func() { server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &EncryptedExtensionsTransportParameters{ NegotiatedVersion: protocol.VersionTLS, @@ -128,6 +131,7 @@ var _ = Describe("Crypto Setup TLS", func() { server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &EncryptedExtensionsTransportParameters{ NegotiatedVersion: protocol.VersionTLS, @@ -208,6 +212,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, _, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &ClientHelloTransportParameters{ InitialVersion: protocol.VersionTLS, @@ -222,6 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() { server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &EncryptedExtensionsTransportParameters{ NegotiatedVersion: protocol.VersionTLS, @@ -258,6 +264,7 @@ var _ = Describe("Crypto Setup TLS", func() { client, chChan, err := NewCryptoSetupClient( cInitialStream, cHandshakeStream, + ioutil.Discard, protocol.ConnectionID{}, &ClientHelloTransportParameters{ InitialVersion: protocol.VersionTLS, @@ -288,58 +295,56 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) }) - // It("receives transport parameters", func() { - // var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters - // cChunkChan, cInitialStream, cHandshakeStream := initStreams() - // cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} - // client, _, err := NewCryptoSetupClient( - // cInitialStream, - // cHandshakeStream, - // nil, - // protocol.ConnectionID{}, - // cTransportParameters, - // func(p *TransportParameters) { sTransportParametersRcvd = p }, - // clientConf, - // protocol.VersionTLS, - // []protocol.VersionNumber{protocol.VersionTLS}, - // protocol.VersionTLS, - // utils.DefaultLogger.WithPrefix("client"), - // protocol.PerspectiveClient, - // ) - // Expect(err).ToNot(HaveOccurred()) + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd []byte + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} + client, _, err := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + &ClientHelloTransportParameters{Parameters: *cTransportParameters}, + func(p []byte) { sTransportParametersRcvd = p }, + clientConf, + utils.DefaultLogger.WithPrefix("client"), + ) + Expect(err).ToNot(HaveOccurred()) - // sChunkChan, sInitialStream, sHandshakeStream := initStreams() - // sTransportParameters := &TransportParameters{ - // IdleTimeout: 0x1337 * time.Second, - // StatelessResetToken: bytes.Repeat([]byte{42}, 16), - // } - // server, err := NewCryptoSetupServer( - // sInitialStream, - // sHandshakeStream, - // protocol.ConnectionID{}, - // sTransportParameters, - // func(p *TransportParameters) { cTransportParametersRcvd = p }, - // testdata.GetTLSConfig(), - // []protocol.VersionNumber{protocol.VersionTLS}, - // protocol.VersionTLS, - // utils.DefaultLogger.WithPrefix("server"), - // protocol.PerspectiveServer, - // ) - // Expect(err).ToNot(HaveOccurred()) + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sTransportParameters := &TransportParameters{ + IdleTimeout: 0x1337 * time.Second, + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + } + server, err := NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + &EncryptedExtensionsTransportParameters{Parameters: *sTransportParameters}, + func(p []byte) { cTransportParametersRcvd = p }, + testdata.GetTLSConfig(), + utils.DefaultLogger.WithPrefix("server"), + ) + Expect(err).ToNot(HaveOccurred()) - // done := make(chan struct{}) - // go func() { - // defer GinkgoRecover() - // clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) - // Expect(clientErr).ToNot(HaveOccurred()) - // Expect(serverErr).ToNot(HaveOccurred()) - // close(done) - // }() - // Eventually(done).Should(BeClosed()) - // Expect(cTransportParametersRcvd).ToNot(BeNil()) - // Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) - // Expect(sTransportParametersRcvd).ToNot(BeNil()) - // Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) - // }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd).ToNot(BeNil()) + chtp := &ClientHelloTransportParameters{} + Expect(chtp.Unmarshal(cTransportParametersRcvd)).To(Succeed()) + Expect(chtp.Parameters.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + eetp := &EncryptedExtensionsTransportParameters{} + Expect(eetp.Unmarshal(sTransportParametersRcvd)).To(Succeed()) + Expect(eetp.Parameters.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) + }) }) }) diff --git a/session.go b/session.go index 47151481..2358ccb0 100644 --- a/session.go +++ b/session.go @@ -168,8 +168,6 @@ var newSession = func( } s.preSetup() s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.logger) - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() s.streamsMap = newStreamsMap( s, s.newFlowController, @@ -179,6 +177,9 @@ var newSession = func( s.version, ) s.framer = newFramer(s.streamsMap, s.version) + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + oneRTTStream := newPostHandshakeCryptoStream(s.framer) eetp := &handshake.EncryptedExtensionsTransportParameters{ NegotiatedVersion: s.version, SupportedVersions: protocol.GetGreasedVersions(conf.Versions), @@ -187,6 +188,7 @@ var newSession = func( cs, err := handshake.NewCryptoSetupServer( initialStream, handshakeStream, + oneRTTStream, clientDestConnID, eetp, s.processTransportParameters, @@ -249,6 +251,7 @@ var newClientSession = func( s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.logger) initialStream := newCryptoStream() handshakeStream := newCryptoStream() + oneRTTStream := newPostHandshakeCryptoStream(s.framer) chtp := &handshake.ClientHelloTransportParameters{ InitialVersion: initialVersion, Parameters: *params, @@ -256,6 +259,7 @@ var newClientSession = func( cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( initialStream, handshakeStream, + oneRTTStream, s.destConnID, chtp, s.processTransportParameters,