forked from quic-go/quic-go
@@ -1,42 +1,48 @@
|
|||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"sync"
|
|
||||||
|
"github.com/hashicorp/golang-lru"
|
||||||
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
compressedCertsCache = map[uint64][]byte{}
|
compressedCertsCache *lru.Cache
|
||||||
compressedCertsCacheMutex sync.RWMutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||||
// Hash all inputs
|
// Hash all inputs
|
||||||
hash := fnv.New64a()
|
hasher := fnv.New64a()
|
||||||
for _, v := range chain {
|
for _, v := range chain {
|
||||||
hash.Write(v)
|
hasher.Write(v)
|
||||||
}
|
}
|
||||||
hash.Write(pCommonSetHashes)
|
hasher.Write(pCommonSetHashes)
|
||||||
hash.Write(pCachedHashes)
|
hasher.Write(pCachedHashes)
|
||||||
hashRes := hash.Sum64()
|
hash := hasher.Sum64()
|
||||||
|
|
||||||
compressedCertsCacheMutex.RLock()
|
var result []byte
|
||||||
result, isCached := compressedCertsCache[hashRes]
|
|
||||||
compressedCertsCacheMutex.RUnlock()
|
|
||||||
if isCached {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
compressedCertsCacheMutex.Lock()
|
resultI, isCached := compressedCertsCache.Get(hash)
|
||||||
defer compressedCertsCacheMutex.Unlock()
|
|
||||||
result, isCached = compressedCertsCache[hashRes]
|
|
||||||
if isCached {
|
if isCached {
|
||||||
return result, nil
|
result = resultI.([]byte)
|
||||||
}
|
} else {
|
||||||
cached, err := compressChain(chain, pCommonSetHashes, pCachedHashes)
|
var err error
|
||||||
|
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
compressedCertsCache[hashRes] = cached
|
compressedCertsCache.Add(hash, result)
|
||||||
return cached, nil
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
lru "github.com/hashicorp/golang-lru"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Certificate cache", func() {
|
var _ = Describe("Certificate cache", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
compressedCertsCache = map[uint64][]byte{}
|
var err error
|
||||||
|
compressedCertsCache, err = lru.New(2)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("gives a compressed cert", func() {
|
It("gives a compressed cert", func() {
|
||||||
@@ -29,11 +32,20 @@ var _ = Describe("Certificate cache", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("stores cached values", func() {
|
It("stores cached values", func() {
|
||||||
Expect(compressedCertsCache).To(HaveLen(0))
|
|
||||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
||||||
compressed, err := getCompressedCert(chain, nil, nil)
|
_, err := getCompressedCert(chain, nil, nil)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(compressedCertsCache).To(HaveLen(1))
|
Expect(compressedCertsCache.Len()).To(Equal(1))
|
||||||
Expect(compressedCertsCache[3838929964809501833]).To(Equal(compressed))
|
Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("evicts old values", func() {
|
||||||
|
_, err := getCompressedCert([][]byte{{0x00}}, nil, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
_, err = getCompressedCert([][]byte{{0x01}}, nil, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
_, err = getCompressedCert([][]byte{{0x02}}, nil, nil)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(compressedCertsCache.Len()).To(Equal(2))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -85,3 +85,6 @@ const MaxIdleTimeout = 1 * time.Minute
|
|||||||
|
|
||||||
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds.
|
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds.
|
||||||
const MaxTimeForCryptoHandshake = 10 * time.Second
|
const MaxTimeForCryptoHandshake = 10 * time.Second
|
||||||
|
|
||||||
|
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
|
||||||
|
const NumCachedCertificates = 128
|
||||||
|
|||||||
Reference in New Issue
Block a user