diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index fd9de586..3668e9ab 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -75,10 +75,19 @@ func (h *CryptoSetup) HandleCryptoStream() error { utils.Infof("Got crypto message:\n%s", printHandshakeMessage(cryptoData)) + sniSlice, ok := cryptoData[TagSNI] + if !ok { + return errors.New("expected SNI in handshake map") + } + sni := string(sniSlice) + if sni == "" { + return errors.New("expected SNI in handshake map") + } + var reply []byte if !h.isInchoateCHLO(cryptoData) { // We have a CHLO with a proper server config ID, do a 0-RTT handshake - reply, err = h.handleCHLO(chloData, cryptoData) + reply, err = h.handleCHLO(sni, chloData, cryptoData) if err != nil { return err } @@ -90,7 +99,7 @@ func (h *CryptoSetup) HandleCryptoStream() error { } // We have an inchoate or non-matching CHLO, we now send a rejection - reply, err = h.handleInchoateCHLO(chloData) + reply, err = h.handleInchoateCHLO(sni, chloData) if err != nil { return err } @@ -155,13 +164,13 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { return false } -func (h *CryptoSetup) handleInchoateCHLO(data []byte) ([]byte, error) { - proof, err := h.scfg.Sign("", data) +func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte) ([]byte, error) { + proof, err := h.scfg.Sign(sni, data) if err != nil { return nil, err } - certCompressed, err := h.scfg.GetCertCompressed("") + certCompressed, err := h.scfg.GetCertCompressed(sni) if err != nil { return nil, err } @@ -176,7 +185,7 @@ func (h *CryptoSetup) handleInchoateCHLO(data []byte) ([]byte, error) { return serverReply.Bytes(), nil } -func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte, error) { +func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) { // We have a CHLO matching our server config, we can continue with the 0-RTT handshake sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { @@ -189,7 +198,7 @@ func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte h.mutex.Lock() defer h.mutex.Unlock() - certUncompressed, err := h.scfg.signer.GetCertUncompressed("") + certUncompressed, err := h.scfg.signer.GetCertUncompressed(sni) if err != nil { return nil, err } diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 811bc633..40f6d0ca 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -125,7 +125,7 @@ var _ = Describe("Crypto setup", func() { Context("when responding to client messages", func() { It("generates REJ messages", func() { - response, err := cs.handleInchoateCHLO([]byte("chlo")) + response, err := cs.handleInchoateCHLO("", []byte("chlo")) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HavePrefix("REJ")) Expect(response).To(ContainSubstring("certcompressed")) @@ -134,7 +134,7 @@ var _ = Describe("Crypto setup", func() { }) It("generates SHLO messages", func() { - response, err := cs.handleCHLO([]byte("chlo-data"), map[Tag][]byte{ + response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{ TagPUBS: []byte("pubs-c"), }) Expect(err).ToNot(HaveOccurred()) @@ -151,16 +151,18 @@ var _ = Describe("Crypto setup", func() { }) It("handles long handshake", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{}) - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{TagSCID: scfg.ID, TagSNO: cs.nonce}) - cs.HandleCryptoStream() + WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{TagSNI: []byte("quic.clemente.io")}) + WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{TagSCID: scfg.ID, TagSNO: cs.nonce, TagSNI: []byte("quic.clemente.io")}) + err := cs.HandleCryptoStream() + Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO")) }) It("handles 0-RTT handshake", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{TagSCID: scfg.ID, TagSNO: cs.nonce}) - cs.HandleCryptoStream() + WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{TagSCID: scfg.ID, TagSNO: cs.nonce, TagSNI: []byte("quic.clemente.io")}) + err := cs.HandleCryptoStream() + Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) }) @@ -178,11 +180,17 @@ var _ = Describe("Crypto setup", func() { }) }) + It("errors without SNI", func() { + WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{}) + err := cs.HandleCryptoStream() + Expect(err).To(MatchError("expected SNI in handshake map")) + }) + Context("escalating crypto", func() { foobarFNVSigned := []byte{0x18, 0x6f, 0x44, 0xba, 0x97, 0x35, 0xd, 0x6f, 0xbf, 0x64, 0x3c, 0x79, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72} doCHLO := func() { - _, err := cs.handleCHLO([]byte("chlo-data"), map[Tag][]byte{TagPUBS: []byte("pubs-c")}) + _, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{TagPUBS: []byte("pubs-c")}) Expect(err).ToNot(HaveOccurred()) }