diff --git a/crypto/cert_cache.go b/crypto/cert_cache.go index 189c5644..3ebdc1ae 100644 --- a/crypto/cert_cache.go +++ b/crypto/cert_cache.go @@ -1,42 +1,48 @@ package crypto import ( + "fmt" "hash/fnv" - "sync" + + "github.com/hashicorp/golang-lru" + "github.com/lucas-clemente/quic-go/protocol" ) var ( - compressedCertsCache = map[uint64][]byte{} - compressedCertsCacheMutex sync.RWMutex + compressedCertsCache *lru.Cache ) func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { // Hash all inputs - hash := fnv.New64a() + hasher := fnv.New64a() for _, v := range chain { - hash.Write(v) + hasher.Write(v) } - hash.Write(pCommonSetHashes) - hash.Write(pCachedHashes) - hashRes := hash.Sum64() + hasher.Write(pCommonSetHashes) + hasher.Write(pCachedHashes) + hash := hasher.Sum64() - compressedCertsCacheMutex.RLock() - result, isCached := compressedCertsCache[hashRes] - compressedCertsCacheMutex.RUnlock() + var result []byte + + resultI, isCached := compressedCertsCache.Get(hash) if isCached { - return result, nil + result = resultI.([]byte) + } else { + var err error + result, err = compressChain(chain, pCommonSetHashes, pCachedHashes) + if err != nil { + return nil, err + } + compressedCertsCache.Add(hash, result) } - compressedCertsCacheMutex.Lock() - defer compressedCertsCacheMutex.Unlock() - result, isCached = compressedCertsCache[hashRes] - if isCached { - return result, nil - } - cached, err := compressChain(chain, pCommonSetHashes, pCachedHashes) - if err != nil { - return nil, err - } - compressedCertsCache[hashRes] = cached - return cached, nil + return result, nil +} + +func init() { + var err error + compressedCertsCache, err = lru.New(protocol.NumCachedCertificates) + if err != nil { + panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error())) + } } diff --git a/crypto/cert_cache_test.go b/crypto/cert_cache_test.go index 3725b086..1ecc26f6 100644 --- a/crypto/cert_cache_test.go +++ b/crypto/cert_cache_test.go @@ -1,13 +1,16 @@ package crypto import ( + lru "github.com/hashicorp/golang-lru" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Certificate cache", func() { BeforeEach(func() { - compressedCertsCache = map[uint64][]byte{} + var err error + compressedCertsCache, err = lru.New(2) + Expect(err).NotTo(HaveOccurred()) }) It("gives a compressed cert", func() { @@ -29,11 +32,20 @@ var _ = Describe("Certificate cache", func() { }) It("stores cached values", func() { - Expect(compressedCertsCache).To(HaveLen(0)) chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}} - compressed, err := getCompressedCert(chain, nil, nil) + _, err := getCompressedCert(chain, nil, nil) Expect(err).NotTo(HaveOccurred()) - Expect(compressedCertsCache).To(HaveLen(1)) - Expect(compressedCertsCache[3838929964809501833]).To(Equal(compressed)) + Expect(compressedCertsCache.Len()).To(Equal(1)) + Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue()) + }) + + It("evicts old values", func() { + _, err := getCompressedCert([][]byte{{0x00}}, nil, nil) + Expect(err).NotTo(HaveOccurred()) + _, err = getCompressedCert([][]byte{{0x01}}, nil, nil) + Expect(err).NotTo(HaveOccurred()) + _, err = getCompressedCert([][]byte{{0x02}}, nil, nil) + Expect(err).NotTo(HaveOccurred()) + Expect(compressedCertsCache.Len()).To(Equal(2)) }) }) diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 661e0319..99894a0e 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -85,3 +85,6 @@ const MaxIdleTimeout = 1 * time.Minute // MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. const MaxTimeForCryptoHandshake = 10 * time.Second + +// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space +const NumCachedCertificates = 128