From ea83ca895099f208d671e0a57017da1006977269 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Wed, 11 May 2016 15:31:05 +0200 Subject: [PATCH] implement cert compression with cached certificates --- crypto/cert_compression.go | 131 ++++++++++++++++++++++++++++++++ crypto/cert_compression_test.go | 98 ++++++++++++++++++++++++ crypto/proof_rsa.go | 36 +-------- crypto/proof_rsa_test.go | 30 +------- crypto/signer.go | 2 +- handshake/crypto_setup.go | 9 ++- handshake/crypto_setup_test.go | 6 +- handshake/server_config.go | 4 +- handshake/tags.go | 4 +- 9 files changed, 248 insertions(+), 72 deletions(-) create mode 100644 crypto/cert_compression.go create mode 100644 crypto/cert_compression_test.go diff --git a/crypto/cert_compression.go b/crypto/cert_compression.go new file mode 100644 index 00000000..a3dbc02d --- /dev/null +++ b/crypto/cert_compression.go @@ -0,0 +1,131 @@ +package crypto + +import ( + "bytes" + "compress/flate" + "compress/zlib" + "encoding/binary" + "errors" + "hash/fnv" + + "github.com/lucas-clemente/quic-go/utils" +) + +type entryType uint8 + +const ( + entryCompressed entryType = 1 + entryCached entryType = 2 + entryCommon entryType = 3 +) + +type entry struct { + t entryType + h uint64 + i uint32 +} + +func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { + res := &bytes.Buffer{} + + cachedHashes, err := splitHashes(pCachedHashes) + if err != nil { + return nil, err + } + + chainHashes := make([]uint64, len(chain)) + for i := range chain { + chainHashes[i] = hashCert(chain[i]) + } + + entries := buildEntries(chain, chainHashes, cachedHashes) + + totalUncompressedLen := 0 + for i, e := range entries { + res.WriteByte(uint8(e.t)) + switch e.t { + case entryCached: + utils.WriteUint64(res, chainHashes[i]) + case entryCompressed: + totalUncompressedLen += 4 + len(chain[i]) + } + } + res.WriteByte(0) // end of list + + if totalUncompressedLen > 0 { + gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain)) + if err != nil { + panic(err) + } + + utils.WriteUint32(res, uint32(totalUncompressedLen)) + + for i, e := range entries { + if e.t != entryCompressed { + continue + } + lenCert := len(chain[i]) + gz.Write([]byte{ + byte(lenCert & 0xff), + byte((lenCert >> 8) & 0xff), + byte((lenCert >> 16) & 0xff), + byte((lenCert >> 24) & 0xff), + }) + gz.Write(chain[i]) + } + + gz.Close() + } + + return res.Bytes(), nil +} + +func buildEntries(chain [][]byte, chainHashes, cachedHashes []uint64) []entry { + res := make([]entry, len(chain)) +chainLoop: + for i := range chain { + // Check if hash is in cachedHashes + for j := range cachedHashes { + if chainHashes[i] == cachedHashes[j] { + res[i] = entry{t: entryCached, h: chainHashes[i]} + continue chainLoop + } + } + + res[i] = entry{t: entryCompressed} + } + return res +} + +func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte { + var dict bytes.Buffer + + // First the cached and common in reverse order + for i := len(entries) - 1; i >= 0; i-- { + if entries[i].t == entryCompressed { + continue + } + dict.Write(chain[i]) + } + + dict.Write(certDictZlib) + return dict.Bytes() +} + +func splitHashes(hashes []byte) ([]uint64, error) { + if len(hashes)%8 != 0 { + return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes") + } + n := len(hashes) / 8 + res := make([]uint64, n) + for i := 0; i < n; i++ { + res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8]) + } + return res, nil +} + +func hashCert(cert []byte) uint64 { + h := fnv.New64() + h.Write(cert) + return h.Sum64() +} diff --git a/crypto/cert_compression_test.go b/crypto/cert_compression_test.go new file mode 100644 index 00000000..7a07ff19 --- /dev/null +++ b/crypto/cert_compression_test.go @@ -0,0 +1,98 @@ +package crypto + +import ( + "bytes" + "compress/flate" + "compress/zlib" + "encoding/binary" + "hash/fnv" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func byteHash(d []byte) []byte { + h := fnv.New64() + h.Write(d) + s := h.Sum64() + res := make([]byte, 8) + binary.LittleEndian.PutUint64(res, s) + return res +} + +var _ = Describe("Cert compression", func() { + It("compresses empty", func() { + compressed, err := compressChain(nil, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(compressed).To(Equal([]byte{0})) + }) + + It("gives correct single cert", func() { + cert := []byte{0xde, 0xca, 0xfb, 0xad} + certZlib := &bytes.Buffer{} + z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) + Expect(err).ToNot(HaveOccurred()) + z.Write([]byte{0x04, 0x00, 0x00, 0x00}) + z.Write(cert) + z.Close() + chain := [][]byte{cert} + compressed, err := compressChain(chain, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(compressed).To(Equal(append([]byte{ + 0x01, 0x00, + 0x08, 0x00, 0x00, 0x00, + }, certZlib.Bytes()...))) + }) + + It("gives correct cert and intermediate", func() { + cert1 := []byte{0xde, 0xca, 0xfb, 0xad} + cert2 := []byte{0xde, 0xad, 0xbe, 0xef} + certZlib := &bytes.Buffer{} + z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) + Expect(err).ToNot(HaveOccurred()) + z.Write([]byte{0x04, 0x00, 0x00, 0x00}) + z.Write(cert1) + z.Write([]byte{0x04, 0x00, 0x00, 0x00}) + z.Write(cert2) + z.Close() + chain := [][]byte{cert1, cert2} + compressed, err := compressChain(chain, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(compressed).To(Equal(append([]byte{ + 0x01, 0x01, 0x00, + 0x10, 0x00, 0x00, 0x00, + }, certZlib.Bytes()...))) + }) + + It("uses cached certificates", func() { + cert := []byte{0xde, 0xca, 0xfb, 0xad} + certHash := byteHash(cert) + chain := [][]byte{cert} + compressed, err := compressChain(chain, nil, certHash) + Expect(err).ToNot(HaveOccurred()) + expected := append([]byte{0x02}, certHash...) + expected = append(expected, 0x00) + Expect(compressed).To(Equal(expected)) + }) + + It("uses cached certificates and compressed combined", func() { + cert1 := []byte{0xde, 0xca, 0xfb, 0xad} + cert2 := []byte{0xde, 0xad, 0xbe, 0xef} + cert2Hash := byteHash(cert2) + certZlib := &bytes.Buffer{} + z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...)) + Expect(err).ToNot(HaveOccurred()) + z.Write([]byte{0x04, 0x00, 0x00, 0x00}) + z.Write(cert1) + z.Close() + chain := [][]byte{cert1, cert2} + compressed, err := compressChain(chain, nil, cert2Hash) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x01, 0x02} + expected = append(expected, cert2Hash...) + expected = append(expected, 0x00) + expected = append(expected, []byte{0x08, 0, 0, 0}...) + expected = append(expected, certZlib.Bytes()...) + Expect(compressed).To(Equal(expected)) + }) +}) diff --git a/crypto/proof_rsa.go b/crypto/proof_rsa.go index ad31afeb..c5b135d3 100644 --- a/crypto/proof_rsa.go +++ b/crypto/proof_rsa.go @@ -1,9 +1,6 @@ package crypto import ( - "bytes" - "compress/flate" - "compress/zlib" "crypto" "crypto/rand" "crypto/rsa" @@ -11,8 +8,6 @@ import ( "crypto/tls" "errors" "strings" - - "github.com/lucas-clemente/quic-go/utils" ) // rsaSigner stores a key and a certificate for the server proof @@ -57,39 +52,12 @@ func (kd *rsaSigner) SignServerProof(sni string, chlo []byte, serverConfigData [ } // GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc -func (kd *rsaSigner) GetCertsCompressed(sni string) ([]byte, error) { +func (kd *rsaSigner) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { cert, err := kd.getCertForSNI(sni) if err != nil { return nil, err } - - b := &bytes.Buffer{} - totalUncompressedLen := 0 - for _, c := range cert.Certificate { - // Entry type compressed - b.WriteByte(1) - totalUncompressedLen += len(c) - } - // Entry type end_of_list - b.WriteByte(0) - // Data + individual lengths as uint32 - utils.WriteUint32(b, uint32(totalUncompressedLen+4*len(cert.Certificate))) - gz, err := zlib.NewWriterLevelDict(b, flate.BestCompression, certDictZlib) - if err != nil { - panic(err) - } - for _, c := range cert.Certificate { - lenCert := len(c) - gz.Write([]byte{ - byte(lenCert & 0xff), - byte((lenCert >> 8) & 0xff), - byte((lenCert >> 16) & 0xff), - byte((lenCert >> 24) & 0xff), - }) - gz.Write(c) - } - gz.Close() - return b.Bytes(), nil + return compressChain(cert.Certificate, pCommonSetHashes, pCachedHashes) } // GetLeafCert gets the leaf certificate diff --git a/crypto/proof_rsa_test.go b/crypto/proof_rsa_test.go index ef227bb8..86818a54 100644 --- a/crypto/proof_rsa_test.go +++ b/crypto/proof_rsa_test.go @@ -15,7 +15,7 @@ import ( ) var _ = Describe("ProofRsa", func() { - It("gives correct cert", func() { + It("compresses certs", func() { cert := []byte{0xde, 0xca, 0xfb, 0xad} certZlib := &bytes.Buffer{} z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) @@ -30,7 +30,7 @@ var _ = Describe("ProofRsa", func() { }, }, } - certCompressed, err := kd.GetCertsCompressed("") + certCompressed, err := kd.GetCertsCompressed("", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(certCompressed).To(Equal(append([]byte{ 0x01, 0x00, @@ -38,32 +38,6 @@ var _ = Describe("ProofRsa", func() { }, certZlib.Bytes()...))) }) - It("gives correct cert with intermediate", func() { - cert1 := []byte{0xde, 0xca, 0xfb, 0xad} - cert2 := []byte{0xde, 0xad, 0xbe, 0xef} - certZlib := &bytes.Buffer{} - z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib) - Expect(err).ToNot(HaveOccurred()) - z.Write([]byte{0x04, 0x00, 0x00, 0x00}) - z.Write(cert1) - z.Write([]byte{0x04, 0x00, 0x00, 0x00}) - z.Write(cert2) - z.Close() - kd := &rsaSigner{ - config: &tls.Config{ - Certificates: []tls.Certificate{ - tls.Certificate{Certificate: [][]byte{cert1, cert2}}, - }, - }, - } - certCompressed, err := kd.GetCertsCompressed("") - Expect(err).ToNot(HaveOccurred()) - Expect(certCompressed).To(Equal(append([]byte{ - 0x01, 0x01, 0x00, - 0x10, 0x00, 0x00, 0x00, - }, certZlib.Bytes()...))) - }) - It("gives valid signatures", func() { key := testdata.GetTLSConfig().Certificates[0].PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey) kd, err := NewRSASigner(testdata.GetTLSConfig()) diff --git a/crypto/signer.go b/crypto/signer.go index 2243eea7..0d9ba4e3 100644 --- a/crypto/signer.go +++ b/crypto/signer.go @@ -3,6 +3,6 @@ package crypto // A Signer holds a certificate and a private key type Signer interface { SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) - GetCertsCompressed(sni string) ([]byte, error) + GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) GetLeafCert(sni string) ([]byte, error) } diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index 98232d60..fc703f8e 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -99,7 +99,7 @@ func (h *CryptoSetup) HandleCryptoStream() error { } // We have an inchoate or non-matching CHLO, we now send a rejection - reply, err = h.handleInchoateCHLO(sni, chloData) + reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData) if err != nil { return err } @@ -164,7 +164,7 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { return false } -func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte) ([]byte, error) { +func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) { var chloOrNil []byte if h.version > protocol.VersionNumber(30) { chloOrNil = data @@ -175,7 +175,10 @@ func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte) ([]byte, error return nil, err } - certCompressed, err := h.scfg.GetCertsCompressed(sni) + commonSetHashes := cryptoData[TagCCS] + cachedCertsHashes := cryptoData[TagCCRT] + + certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes) if err != nil { return nil, err } diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 0d27ae75..13d6e767 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -39,7 +39,7 @@ func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData [ } return []byte("proof"), nil } -func (*mockSigner) GetCertsCompressed(sni string) ([]byte, error) { +func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) { return []byte("certcompressed"), nil } func (*mockSigner) GetLeafCert(sni string) ([]byte, error) { @@ -125,7 +125,7 @@ var _ = Describe("Crypto setup", func() { Context("when responding to client messages", func() { It("generates REJ messages", func() { - response, err := cs.handleInchoateCHLO("", []byte("chlo")) + response, err := cs.handleInchoateCHLO("", []byte("chlo"), nil) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HavePrefix("REJ")) Expect(response).To(ContainSubstring("certcompressed")) @@ -135,7 +135,7 @@ var _ = Describe("Crypto setup", func() { It("generates REJ messages for version 30", func() { cs.version = protocol.VersionNumber(30) - _, err := cs.handleInchoateCHLO("", sampleCHLO) + _, err := cs.handleInchoateCHLO("", sampleCHLO, nil) Expect(err).ToNot(HaveOccurred()) Expect(signer.gotCHLO).To(BeFalse()) }) diff --git a/handshake/server_config.go b/handshake/server_config.go index 157151d4..02607448 100644 --- a/handshake/server_config.go +++ b/handshake/server_config.go @@ -50,6 +50,6 @@ func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) { } // GetCertsCompressed returns the certificate data -func (s *ServerConfig) GetCertsCompressed(sni string) ([]byte, error) { - return s.signer.GetCertsCompressed(sni) +func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) { + return s.signer.GetCertsCompressed(sni, commonSetHashes, compressedHashes) } diff --git a/handshake/tags.go b/handshake/tags.go index 6bbb6583..ed96c407 100644 --- a/handshake/tags.go +++ b/handshake/tags.go @@ -17,8 +17,10 @@ const ( TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16 // TagVER is the QUIC version TagVER Tag = 'V' + 'E'<<8 + 'R'<<16 - // TagCCS is the hash of the common certificate sets + // TagCCS are the hashes of the common certificate sets TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16 + // TagCCRT are the hashes of the cached certificates + TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24 // TagMSPC is max streams per connection TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24 // TagUAID is the user agent ID