diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 269e66732..96170d0a5 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -38,8 +38,8 @@ type cryptoSetupServer struct { secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD receivedForwardSecurePacket bool - sentSHLO bool receivedSecurePacket bool + sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written aeadChanged chan<- protocol.EncryptionLevel keyDerivation KeyDerivationFunction @@ -93,6 +93,7 @@ func NewCryptoSetup( cryptoStream: cryptoStream, connectionParameters: connectionParametersManager, acceptSTKCallback: acceptSTK, + sentSHLO: make(chan struct{}), aeadChanged: aeadChanged, }, nil } @@ -167,10 +168,11 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if err != nil { return false, err } - _, err = h.cryptoStream.Write(reply) - if err != nil { + if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } + h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.sentSHLO) return true, nil } @@ -193,6 +195,8 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client h.receivedForwardSecurePacket = true + // wait until protocol.EncryptionForwardSecure was sent on the aeadChan + <-h.sentSHLO close(h.aeadChanged) } return res, protocol.EncryptionForwardSecure, nil @@ -451,9 +455,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T var reply bytes.Buffer message.Write(&reply) utils.Debugf("Sending %s", message) - - h.aeadChanged <- protocol.EncryptionForwardSecure - return reply.Bytes(), nil } diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 304f9f34e..2fd411751 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -350,7 +350,6 @@ var _ = Describe("Server Crypto Setup", func() { Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) }) @@ -384,6 +383,7 @@ var _ = Describe("Server Crypto Setup", func() { Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) + Expect(aeadChanged).ToNot(BeClosed()) }) It("recognizes inchoate CHLOs missing SCID", func() { @@ -554,6 +554,8 @@ var _ = Describe("Server Crypto Setup", func() { TagKEXS: kexs, }) Expect(err).ToNot(HaveOccurred()) + Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) + close(cs.sentSHLO) } Context("null encryption", func() { @@ -654,8 +656,6 @@ var _ = Describe("Server Crypto Setup", func() { doCHLO() _, _, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{}) Expect(err).ToNot(HaveOccurred()) - Expect(aeadChanged).To(Receive()) // consume the protocol.EncryptionSecure - Expect(aeadChanged).To(Receive()) // consume the protocol.EncryptionForwardSecure Expect(aeadChanged).To(BeClosed()) }) })