forked from quic-go/quic-go
implement cert compression with cached certificates
This commit is contained in:
131
crypto/cert_compression.go
Normal file
131
crypto/cert_compression.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type entryType uint8
|
||||
|
||||
const (
|
||||
entryCompressed entryType = 1
|
||||
entryCached entryType = 2
|
||||
entryCommon entryType = 3
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
t entryType
|
||||
h uint64
|
||||
i uint32
|
||||
}
|
||||
|
||||
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
res := &bytes.Buffer{}
|
||||
|
||||
cachedHashes, err := splitHashes(pCachedHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chainHashes := make([]uint64, len(chain))
|
||||
for i := range chain {
|
||||
chainHashes[i] = hashCert(chain[i])
|
||||
}
|
||||
|
||||
entries := buildEntries(chain, chainHashes, cachedHashes)
|
||||
|
||||
totalUncompressedLen := 0
|
||||
for i, e := range entries {
|
||||
res.WriteByte(uint8(e.t))
|
||||
switch e.t {
|
||||
case entryCached:
|
||||
utils.WriteUint64(res, chainHashes[i])
|
||||
case entryCompressed:
|
||||
totalUncompressedLen += 4 + len(chain[i])
|
||||
}
|
||||
}
|
||||
res.WriteByte(0) // end of list
|
||||
|
||||
if totalUncompressedLen > 0 {
|
||||
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
utils.WriteUint32(res, uint32(totalUncompressedLen))
|
||||
|
||||
for i, e := range entries {
|
||||
if e.t != entryCompressed {
|
||||
continue
|
||||
}
|
||||
lenCert := len(chain[i])
|
||||
gz.Write([]byte{
|
||||
byte(lenCert & 0xff),
|
||||
byte((lenCert >> 8) & 0xff),
|
||||
byte((lenCert >> 16) & 0xff),
|
||||
byte((lenCert >> 24) & 0xff),
|
||||
})
|
||||
gz.Write(chain[i])
|
||||
}
|
||||
|
||||
gz.Close()
|
||||
}
|
||||
|
||||
return res.Bytes(), nil
|
||||
}
|
||||
|
||||
func buildEntries(chain [][]byte, chainHashes, cachedHashes []uint64) []entry {
|
||||
res := make([]entry, len(chain))
|
||||
chainLoop:
|
||||
for i := range chain {
|
||||
// Check if hash is in cachedHashes
|
||||
for j := range cachedHashes {
|
||||
if chainHashes[i] == cachedHashes[j] {
|
||||
res[i] = entry{t: entryCached, h: chainHashes[i]}
|
||||
continue chainLoop
|
||||
}
|
||||
}
|
||||
|
||||
res[i] = entry{t: entryCompressed}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
|
||||
var dict bytes.Buffer
|
||||
|
||||
// First the cached and common in reverse order
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
if entries[i].t == entryCompressed {
|
||||
continue
|
||||
}
|
||||
dict.Write(chain[i])
|
||||
}
|
||||
|
||||
dict.Write(certDictZlib)
|
||||
return dict.Bytes()
|
||||
}
|
||||
|
||||
func splitHashes(hashes []byte) ([]uint64, error) {
|
||||
if len(hashes)%8 != 0 {
|
||||
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
|
||||
}
|
||||
n := len(hashes) / 8
|
||||
res := make([]uint64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func hashCert(cert []byte) uint64 {
|
||||
h := fnv.New64()
|
||||
h.Write(cert)
|
||||
return h.Sum64()
|
||||
}
|
||||
98
crypto/cert_compression_test.go
Normal file
98
crypto/cert_compression_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func byteHash(d []byte) []byte {
|
||||
h := fnv.New64()
|
||||
h.Write(d)
|
||||
s := h.Sum64()
|
||||
res := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(res, s)
|
||||
return res
|
||||
}
|
||||
|
||||
var _ = Describe("Cert compression", func() {
|
||||
It("compresses empty", func() {
|
||||
compressed, err := compressChain(nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal([]byte{0}))
|
||||
})
|
||||
|
||||
It("gives correct single cert", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert)
|
||||
z.Close()
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal(append([]byte{
|
||||
0x01, 0x00,
|
||||
0x08, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("gives correct cert and intermediate", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert2)
|
||||
z.Close()
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal(append([]byte{
|
||||
0x01, 0x01, 0x00,
|
||||
0x10, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("uses cached certificates", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certHash := byteHash(cert)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, nil, certHash)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := append([]byte{0x02}, certHash...)
|
||||
expected = append(expected, 0x00)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("uses cached certificates and compressed combined", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
cert2Hash := byteHash(cert2)
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Close()
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, cert2Hash)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x01, 0x02}
|
||||
expected = append(expected, cert2Hash...)
|
||||
expected = append(expected, 0x00)
|
||||
expected = append(expected, []byte{0x08, 0, 0, 0}...)
|
||||
expected = append(expected, certZlib.Bytes()...)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
@@ -1,9 +1,6 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -11,8 +8,6 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
// rsaSigner stores a key and a certificate for the server proof
|
||||
@@ -57,39 +52,12 @@ func (kd *rsaSigner) SignServerProof(sni string, chlo []byte, serverConfigData [
|
||||
}
|
||||
|
||||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
||||
func (kd *rsaSigner) GetCertsCompressed(sni string) ([]byte, error) {
|
||||
func (kd *rsaSigner) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
cert, err := kd.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
totalUncompressedLen := 0
|
||||
for _, c := range cert.Certificate {
|
||||
// Entry type compressed
|
||||
b.WriteByte(1)
|
||||
totalUncompressedLen += len(c)
|
||||
}
|
||||
// Entry type end_of_list
|
||||
b.WriteByte(0)
|
||||
// Data + individual lengths as uint32
|
||||
utils.WriteUint32(b, uint32(totalUncompressedLen+4*len(cert.Certificate)))
|
||||
gz, err := zlib.NewWriterLevelDict(b, flate.BestCompression, certDictZlib)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, c := range cert.Certificate {
|
||||
lenCert := len(c)
|
||||
gz.Write([]byte{
|
||||
byte(lenCert & 0xff),
|
||||
byte((lenCert >> 8) & 0xff),
|
||||
byte((lenCert >> 16) & 0xff),
|
||||
byte((lenCert >> 24) & 0xff),
|
||||
})
|
||||
gz.Write(c)
|
||||
}
|
||||
gz.Close()
|
||||
return b.Bytes(), nil
|
||||
return compressChain(cert.Certificate, pCommonSetHashes, pCachedHashes)
|
||||
}
|
||||
|
||||
// GetLeafCert gets the leaf certificate
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("ProofRsa", func() {
|
||||
It("gives correct cert", func() {
|
||||
It("compresses certs", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
@@ -30,7 +30,7 @@ var _ = Describe("ProofRsa", func() {
|
||||
},
|
||||
},
|
||||
}
|
||||
certCompressed, err := kd.GetCertsCompressed("")
|
||||
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(certCompressed).To(Equal(append([]byte{
|
||||
0x01, 0x00,
|
||||
@@ -38,32 +38,6 @@ var _ = Describe("ProofRsa", func() {
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("gives correct cert with intermediate", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert2)
|
||||
z.Close()
|
||||
kd := &rsaSigner{
|
||||
config: &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
tls.Certificate{Certificate: [][]byte{cert1, cert2}},
|
||||
},
|
||||
},
|
||||
}
|
||||
certCompressed, err := kd.GetCertsCompressed("")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(certCompressed).To(Equal(append([]byte{
|
||||
0x01, 0x01, 0x00,
|
||||
0x10, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("gives valid signatures", func() {
|
||||
key := testdata.GetTLSConfig().Certificates[0].PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey)
|
||||
kd, err := NewRSASigner(testdata.GetTLSConfig())
|
||||
|
||||
@@ -3,6 +3,6 @@ package crypto
|
||||
// A Signer holds a certificate and a private key
|
||||
type Signer interface {
|
||||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
|
||||
GetCertsCompressed(sni string) ([]byte, error)
|
||||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
|
||||
GetLeafCert(sni string) ([]byte, error)
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ func (h *CryptoSetup) HandleCryptoStream() error {
|
||||
}
|
||||
|
||||
// We have an inchoate or non-matching CHLO, we now send a rejection
|
||||
reply, err = h.handleInchoateCHLO(sni, chloData)
|
||||
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -164,7 +164,7 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte) ([]byte, error) {
|
||||
func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
var chloOrNil []byte
|
||||
if h.version > protocol.VersionNumber(30) {
|
||||
chloOrNil = data
|
||||
@@ -175,7 +175,10 @@ func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte) ([]byte, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certCompressed, err := h.scfg.GetCertsCompressed(sni)
|
||||
commonSetHashes := cryptoData[TagCCS]
|
||||
cachedCertsHashes := cryptoData[TagCCRT]
|
||||
|
||||
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData [
|
||||
}
|
||||
return []byte("proof"), nil
|
||||
}
|
||||
func (*mockSigner) GetCertsCompressed(sni string) ([]byte, error) {
|
||||
func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) {
|
||||
return []byte("certcompressed"), nil
|
||||
}
|
||||
func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
|
||||
@@ -125,7 +125,7 @@ var _ = Describe("Crypto setup", func() {
|
||||
|
||||
Context("when responding to client messages", func() {
|
||||
It("generates REJ messages", func() {
|
||||
response, err := cs.handleInchoateCHLO("", []byte("chlo"))
|
||||
response, err := cs.handleInchoateCHLO("", []byte("chlo"), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
Expect(response).To(ContainSubstring("certcompressed"))
|
||||
@@ -135,7 +135,7 @@ var _ = Describe("Crypto setup", func() {
|
||||
|
||||
It("generates REJ messages for version 30", func() {
|
||||
cs.version = protocol.VersionNumber(30)
|
||||
_, err := cs.handleInchoateCHLO("", sampleCHLO)
|
||||
_, err := cs.handleInchoateCHLO("", sampleCHLO, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(signer.gotCHLO).To(BeFalse())
|
||||
})
|
||||
|
||||
@@ -50,6 +50,6 @@ func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// GetCertsCompressed returns the certificate data
|
||||
func (s *ServerConfig) GetCertsCompressed(sni string) ([]byte, error) {
|
||||
return s.signer.GetCertsCompressed(sni)
|
||||
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
|
||||
return s.signer.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
|
||||
}
|
||||
|
||||
@@ -17,8 +17,10 @@ const (
|
||||
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
|
||||
// TagVER is the QUIC version
|
||||
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
|
||||
// TagCCS is the hash of the common certificate sets
|
||||
// TagCCS are the hashes of the common certificate sets
|
||||
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
|
||||
// TagCCRT are the hashes of the cached certificates
|
||||
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
|
||||
// TagMSPC is max streams per connection
|
||||
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
|
||||
// TagUAID is the user agent ID
|
||||
|
||||
Reference in New Issue
Block a user