diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index a4acab937..cb43b49d8 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -46,7 +46,7 @@ type cryptoSetupClient struct { receivedSecurePacket bool secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - aeadChanged chan struct{} + aeadChanged chan protocol.EncryptionLevel connectionParameters ConnectionParametersManager } @@ -67,7 +67,7 @@ func NewCryptoSetupClient( cryptoStream io.ReadWriter, tlsConfig *tls.Config, connectionParameters ConnectionParametersManager, - aeadChanged chan struct{}, + aeadChanged chan protocol.EncryptionLevel, negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { return &cryptoSetupClient{ @@ -245,7 +245,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { return qerr.InvalidCryptoMessageParameter } - h.aeadChanged <- struct{}{} + h.aeadChanged <- protocol.EncryptionForwardSecure return nil } @@ -460,7 +460,7 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { return err } - h.aeadChanged <- struct{}{} + h.aeadChanged <- protocol.EncryptionSecure } return nil diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index b206db002..bece3b408 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -127,7 +127,7 @@ var _ = Describe("Crypto setup", func() { stream = &mockStream{} certManager = &mockCertManager{} version := protocol.Version36 - csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil) + csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan protocol.EncryptionLevel, 2), nil) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) cs.certManager = certManager diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 2293716a9..7176b8617 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -31,7 +31,7 @@ type cryptoSetupServer struct { forwardSecureAEAD crypto.AEAD receivedForwardSecurePacket bool receivedSecurePacket bool - aeadChanged chan struct{} + aeadChanged chan protocol.EncryptionLevel keyDerivation KeyDerivationFunction keyExchange KeyExchangeFunction @@ -53,7 +53,7 @@ func NewCryptoSetup( scfg *ServerConfig, cryptoStream io.ReadWriter, connectionParametersManager ConnectionParametersManager, - aeadChanged chan struct{}, + aeadChanged chan protocol.EncryptionLevel, ) (CryptoSetup, error) { return &cryptoSetupServer{ connID: connID, @@ -357,7 +357,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T WriteHandshakeMessage(&reply, TagSHLO, replyMap) utils.Debugf("Sending SHLO:\n%s", printHandshakeMessage(replyMap)) - h.aeadChanged <- struct{}{} + h.aeadChanged <- protocol.EncryptionForwardSecure return reply.Bytes(), nil } diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index e1b8eb1f8..f90a84677 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -141,7 +141,7 @@ var _ = Describe("Crypto setup", func() { cs *cryptoSetupServer stream *mockStream cpm ConnectionParametersManager - aeadChanged chan struct{} + aeadChanged chan protocol.EncryptionLevel nonce32 []byte versionTag []byte sourceAddr []byte @@ -157,7 +157,7 @@ var _ = Describe("Crypto setup", func() { Expect(err).NotTo(HaveOccurred()) expectedInitialNonceLen = 32 expectedFSNonceLen = 64 - aeadChanged = make(chan struct{}, 1) + aeadChanged = make(chan protocol.EncryptionLevel, 2) stream = &mockStream{} kex = &mockKEX{} signer = &mockSigner{} diff --git a/session.go b/session.go index e6447d283..53271e782 100644 --- a/session.go +++ b/session.go @@ -77,7 +77,7 @@ type session struct { closed uint32 // atomic bool undecryptablePackets []*receivedPacket - aeadChanged chan struct{} + aeadChanged chan protocol.EncryptionLevel nextAckScheduledTime time.Time @@ -178,7 +178,7 @@ func (s *session) setup() { s.closeChan = make(chan *qerr.QuicError, 1) s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) - s.aeadChanged = make(chan struct{}, 1) + s.aeadChanged = make(chan protocol.EncryptionLevel, 2) s.runClosed = make(chan struct{}, 1) s.timer = time.NewTimer(0) @@ -235,9 +235,9 @@ runLoop: // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) - case <-s.aeadChanged: + case l := <-s.aeadChanged: s.tryDecryptingQueuedPackets() - s.cryptoChangeCallback(s, s.cryptoSetup.HandshakeComplete()) + s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) } if err != nil { diff --git a/session_test.go b/session_test.go index 40d703881..519cc300b 100644 --- a/session_test.go +++ b/session_test.go @@ -1146,8 +1146,7 @@ var _ = Describe("Session", func() { callbackSession = s } sess.cryptoChangeCallback = cb - sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: false} - sess.aeadChanged <- struct{}{} + sess.aeadChanged <- protocol.EncryptionSecure go sess.run() Eventually(func() bool { return callbackCalled }).Should(BeTrue()) Expect(callbackCalledWith).To(BeFalse()) @@ -1164,7 +1163,7 @@ var _ = Describe("Session", func() { } sess.cryptoChangeCallback = cb sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: true} - sess.aeadChanged <- struct{}{} + sess.aeadChanged <- protocol.EncryptionForwardSecure go sess.run() Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) Expect(callbackSession).To(Equal(sess))