diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index afb3d9ab..acb8144c 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/binary" "errors" + "fmt" "io" "time" @@ -32,10 +33,11 @@ type cryptoSetupClient struct { lastSentCHLO []byte certManager crypto.CertManager - serverVerified bool // has the certificate chain and the proof already been verified - keyDerivation KeyDerivationFunction - secureAEAD crypto.AEAD - forwardSecureAEAD crypto.AEAD + clientHelloCounter int + serverVerified bool // has the certificate chain and the proof already been verified + keyDerivation KeyDerivationFunction + secureAEAD crypto.AEAD + forwardSecureAEAD crypto.AEAD } var _ crypto.AEAD = &cryptoSetupClient{} @@ -72,9 +74,12 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { return err } - err = h.sendCHLO() - if err != nil { - return err + // send CHLOs until the forward secure encryption is established + if h.forwardSecureAEAD == nil { + err = h.sendCHLO() + if err != nil { + return err + } } var shloData bytes.Buffer @@ -260,6 +265,11 @@ func (h *cryptoSetupClient) HandshakeComplete() bool { } func (h *cryptoSetupClient) sendCHLO() error { + h.clientHelloCounter++ + if h.clientHelloCounter > protocol.MaxClientHellos { + return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos)) + } + b := &bytes.Buffer{} tags := h.getTags() diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index fe2df4de..57b898af 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -403,6 +403,17 @@ var _ = Describe("Crypto setup", func() { Expect(tags[TagNONC]).To(Equal(cs.nonc)) Expect(tags[TagPUBS]).To(Equal(kex.PublicKey())) }) + + It("doesn't send more than MaxClientHellos CHLOs", func() { + Expect(cs.clientHelloCounter).To(BeZero()) + for i := 1; i <= protocol.MaxClientHellos; i++ { + err := cs.sendCHLO() + Expect(err).ToNot(HaveOccurred()) + Expect(cs.clientHelloCounter).To(Equal(i)) + } + err := cs.sendCHLO() + Expect(err).To(MatchError(qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos)))) + }) }) Context("escalating crypto", func() { diff --git a/protocol/protocol.go b/protocol/protocol.go index aae2e0a8..bb36e3be 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -64,3 +64,9 @@ const MaxRetransmissionTime = 60 * time.Second // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. const ClientHelloMinimumSize = 1024 + +// MaxClientHellos is the maximum number of times we'll send a client hello +// The value 3 accounts for: +// * one failure due to an incorrect or missing source-address token +// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token +const MaxClientHellos = 3