diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index 6eec9502f..ec8326e9b 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -174,6 +174,9 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { if !ok || !bytes.Equal(h.scfg.ID, scid) { return true } + if _, ok := cryptoData[TagPUBS]; !ok { + return true + } if err := h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]); err != nil { utils.Infof("STK invalid: %s", err.Error()) return false @@ -186,36 +189,41 @@ func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte, cryptoData map return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") } - var chloOrNil []byte - if h.version > protocol.Version30 { - chloOrNil = data - } - - proof, err := h.scfg.Sign(sni, chloOrNil) - if err != nil { - return nil, err - } - - commonSetHashes := cryptoData[TagCCS] - cachedCertsHashes := cryptoData[TagCCRT] - - certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes) - if err != nil { - return nil, err - } - token, err := h.scfg.stkSource.NewToken(h.ip) if err != nil { return nil, err } - var serverReply bytes.Buffer - WriteHandshakeMessage(&serverReply, TagREJ, map[Tag][]byte{ + replyMap := map[Tag][]byte{ TagSCFG: h.scfg.Get(), - TagCERT: certCompressed, - TagPROF: proof, TagSTK: token, - }) + } + + if h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]) == nil { + var chloOrNil []byte + if h.version > protocol.Version30 { + chloOrNil = data + } + + proof, err := h.scfg.Sign(sni, chloOrNil) + if err != nil { + return nil, err + } + + commonSetHashes := cryptoData[TagCCS] + cachedCertsHashes := cryptoData[TagCCRT] + + certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes) + if err != nil { + return nil, err + } + // Token was valid, send more details + replyMap[TagPROF] = proof + replyMap[TagCERT] = certCompressed + } + + var serverReply bytes.Buffer + WriteHandshakeMessage(&serverReply, TagREJ, replyMap) return serverReply.Bytes(), nil } diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 25f6d3162..00492e54f 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -203,8 +203,28 @@ var _ = Describe("Crypto setup", func() { response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HavePrefix("REJ")) - Expect(response).To(ContainSubstring("certcompressed")) Expect(response).To(ContainSubstring("initial public")) + Expect(signer.gotCHLO).To(BeFalse()) + }) + + It("REJ messages don't include cert or proof without STK", func() { + response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HavePrefix("REJ")) + Expect(response).ToNot(ContainSubstring("certcompressed")) + Expect(response).ToNot(ContainSubstring("proof")) + Expect(signer.gotCHLO).To(BeFalse()) + }) + + It("REJ messages include cert and proof with valid STK", func() { + response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ + TagSTK: validSTK, + TagSNI: []byte("foo"), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(response).To(HavePrefix("REJ")) + Expect(response).To(ContainSubstring("certcompressed")) + Expect(response).To(ContainSubstring("proof")) Expect(signer.gotCHLO).To(BeTrue()) }) @@ -244,6 +264,7 @@ var _ = Describe("Crypto setup", func() { TagSNI: []byte("quic.clemente.io"), TagNONC: nonce32, TagSTK: validSTK, + TagPUBS: nil, }) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) @@ -258,6 +279,7 @@ var _ = Describe("Crypto setup", func() { TagSNI: []byte("quic.clemente.io"), TagNONC: nonce32, TagSTK: validSTK, + TagPUBS: nil, }) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) @@ -267,11 +289,18 @@ var _ = Describe("Crypto setup", func() { }) It("recognizes inchoate CHLOs missing SCID", func() { - Expect(cs.isInchoateCHLO(map[Tag][]byte{})).To(BeTrue()) + Expect(cs.isInchoateCHLO(map[Tag][]byte{TagPUBS: nil})).To(BeTrue()) + }) + + It("recognizes inchoate CHLOs missing PUBS", func() { + Expect(cs.isInchoateCHLO(map[Tag][]byte{TagSCID: nil})).To(BeTrue()) }) It("recognizes proper CHLOs", func() { - Expect(cs.isInchoateCHLO(map[Tag][]byte{TagSCID: scfg.ID})).To(BeFalse()) + Expect(cs.isInchoateCHLO(map[Tag][]byte{ + TagSCID: scfg.ID, + TagPUBS: nil, + })).To(BeFalse()) }) It("errors on too short inchoate CHLOs", func() {