diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 92eadba63..f69151121 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -40,6 +40,7 @@ type cryptoSetupClient struct { receivedSecurePacket bool secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD + aeadChanged chan struct{} connectionParameters ConnectionParametersManager } @@ -60,6 +61,7 @@ func NewCryptoSetupClient( version protocol.VersionNumber, cryptoStream utils.Stream, connectionParameters ConnectionParametersManager, + aeadChanged chan struct{}, ) (CryptoSetup, error) { return &cryptoSetupClient{ hostname: hostname, @@ -69,6 +71,7 @@ func NewCryptoSetupClient( certManager: crypto.NewCertManager(), connectionParameters: connectionParameters, keyDerivation: crypto.DeriveKeysAESGCM, + aeadChanged: aeadChanged, }, nil } @@ -223,6 +226,8 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { return err } + h.aeadChanged <- struct{}{} + return nil } @@ -399,6 +404,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } + + h.aeadChanged <- struct{}{} } return nil diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index b59482c55..99045effb 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -121,7 +121,7 @@ var _ = Describe("Crypto setup", func() { stream = &mockStream{} certManager = &mockCertManager{} version := protocol.Version36 - csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version)) + csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1)) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) cs.certManager = certManager @@ -353,6 +353,7 @@ var _ = Describe("Crypto setup", func() { err := cs.handleSHLOMessage(tagMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).ToNot(Receive()) }) It("rejects SHLOs without a PUBS", func() { @@ -382,6 +383,7 @@ var _ = Describe("Crypto setup", func() { Expect(err).ToNot(HaveOccurred()) Expect(cs.forwardSecureAEAD).ToNot(BeNil()) Expect(cs.HandshakeComplete()).To(BeTrue()) + Expect(cs.aeadChanged).To(Receive()) }) }) @@ -541,6 +543,7 @@ var _ = Describe("Crypto setup", func() { Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce)) Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient)) Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).To(Receive()) }) It("uses the server nonce, if the server sent one", func() { @@ -551,18 +554,21 @@ var _ = Describe("Crypto setup", func() { Expect(cs.secureAEAD).ToNot(BeNil()) Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...))) Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).To(Receive()) }) It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() { err := cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).To(BeNil()) + Expect(cs.aeadChanged).ToNot(Receive()) cs.serverVerified = true // make sure we really had all necessary values before, and only serverVerified was missing err = cs.maybeUpgradeCrypto() Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.HandshakeComplete()).To(BeFalse()) + Expect(cs.aeadChanged).To(Receive()) }) It("tries to escalate before reading a handshake message", func() { @@ -583,6 +589,7 @@ var _ = Describe("Crypto setup", func() { err := cs.SetDiversificationNonce([]byte("div")) Expect(err).ToNot(HaveOccurred()) Expect(cs.secureAEAD).ToNot(BeNil()) + Expect(cs.aeadChanged).To(Receive()) Expect(cs.HandshakeComplete()).To(BeFalse()) }) }) diff --git a/session.go b/session.go index 162c98dd7..a28575ef6 100644 --- a/session.go +++ b/session.go @@ -138,7 +138,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p cryptoStream, _ := session.GetOrOpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters) + session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged) if err != nil { return nil, err }