forked from quic-go/quic-go
Merge pull request #2125 from lucas-clemente/chacha-header-protection
implement ChaCha20 header protection
This commit is contained in:
@@ -1,32 +1,28 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type sealer struct {
|
||||
aead cipher.AEAD
|
||||
hpEncrypter cipher.Block
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &sealer{}
|
||||
|
||||
func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer {
|
||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
return &sealer{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
hpEncrypter: hpEncrypter,
|
||||
hpMask: make([]byte, hpEncrypter.BlockSize()),
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by
|
||||
}
|
||||
|
||||
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != s.hpEncrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
s.hpEncrypter.Encrypt(s.hpMask, sample)
|
||||
*firstByte ^= s.hpMask[0] & 0xf
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= s.hpMask[i+1]
|
||||
}
|
||||
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func (s *sealer) Overhead() int {
|
||||
@@ -53,22 +42,20 @@ func (s *sealer) Overhead() int {
|
||||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
pnDecrypter cipher.Block
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
|
||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
pnDecrypter: pnDecrypter,
|
||||
hpMask: make([]byte, pnDecrypter.BlockSize()),
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,27 +71,11 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != o.pnDecrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
o.pnDecrypter.Encrypt(o.hpMask, sample)
|
||||
*firstByte ^= o.hpMask[0] & 0xf
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= o.hpMask[i+1]
|
||||
}
|
||||
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
||||
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
|
||||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
|
||||
key := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic key", suite.KeyLen)
|
||||
iv := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
return suite.AEAD(key, iv)
|
||||
}
|
||||
|
||||
func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hp, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return hp
|
||||
}
|
||||
|
||||
@@ -4,86 +4,90 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("AEAD", func() {
|
||||
getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
rand.Read(hpKey)
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aead, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hpBlock, err := aes.NewCipher(hpKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := range cipherSuites {
|
||||
cs := cipherSuites[i]
|
||||
|
||||
iv := make([]byte, 12)
|
||||
rand.Read(iv)
|
||||
return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock)
|
||||
}
|
||||
Context(fmt.Sprintf("using %s", cipherSuiteName(cs.ID)), func() {
|
||||
getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
rand.Read(hpKey)
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aead, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Context("message encryption", func() {
|
||||
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad := []byte("Donec in velit neque.")
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := opener.Open(nil, encrypted, 0x1337, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x42, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
})
|
||||
|
||||
Context("header encryption", func() {
|
||||
It("encrypts and encrypts the header", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
var lastFourBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
sealer.EncryptHeader(sample, &header[0], header[9:13])
|
||||
if header[0]&0xf != 0xb5&0xf {
|
||||
lastFourBitsDifferent++
|
||||
}
|
||||
Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0)))
|
||||
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
opener.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
return newLongHeaderSealer(aead, newHeaderProtector(cs, key, true)),
|
||||
newLongHeaderOpener(aead, newHeaderProtector(cs, key, true))
|
||||
}
|
||||
Expect(lastFourBitsDifferent).To(BeNumerically(">", 75))
|
||||
})
|
||||
|
||||
It("fails to decrypt the header when using a different sample", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
sealer.EncryptHeader(sample, &header[0], header[9:13])
|
||||
rand.Read(sample) // use a different sample
|
||||
opener.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
Context("message encryption", func() {
|
||||
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad := []byte("Donec in velit neque.")
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := opener.Open(nil, encrypted, 0x1337, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := opener.Open(nil, encrypted, 0x42, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
})
|
||||
|
||||
Context("header encryption", func() {
|
||||
It("encrypts and encrypts the header", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
var lastFourBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
sealer.EncryptHeader(sample, &header[0], header[9:13])
|
||||
if header[0]&0xf != 0xb5&0xf {
|
||||
lastFourBitsDifferent++
|
||||
}
|
||||
Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0)))
|
||||
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
opener.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
}
|
||||
Expect(lastFourBitsDifferent).To(BeNumerically(">", 75))
|
||||
})
|
||||
|
||||
It("fails to decrypt the header when using a different sample", func() {
|
||||
sealer, opener := getSealerAndOpener()
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
sealer.EncryptHeader(sample, &header[0], header[9:13])
|
||||
rand.Read(sample) // use a different sample
|
||||
opener.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
@@ -491,21 +491,21 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret),
|
||||
createHeaderProtector(suite, trafficSecret),
|
||||
newHeaderProtector(suite, trafficSecret, true),
|
||||
)
|
||||
h.logger.Debugf("Installed Handshake Read keys")
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", cipherSuiteName(suite.ID))
|
||||
case qtls.EncryptionApplication:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
h.logger.Debugf("Installed 1-RTT Read keys")
|
||||
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", cipherSuiteName(suite.ID))
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
@@ -513,21 +513,21 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
||||
h.receivedReadKey <- struct{}{}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.EncryptionHandshake:
|
||||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret),
|
||||
createHeaderProtector(suite, trafficSecret),
|
||||
newHeaderProtector(suite, trafficSecret, true),
|
||||
)
|
||||
h.logger.Debugf("Installed Handshake Write keys")
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", cipherSuiteName(suite.ID))
|
||||
case qtls.EncryptionApplication:
|
||||
h.writeEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
h.logger.Debugf("Installed 1-RTT Write keys")
|
||||
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", cipherSuiteName(suite.ID))
|
||||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
|
||||
"github.com/alangpierce/go-forceexport"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/marten-seemann/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuicGo(t *testing.T) {
|
||||
func TestHandshake(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Handshake Suite")
|
||||
}
|
||||
@@ -22,3 +28,37 @@ var _ = BeforeEach(func() {
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
})
|
||||
|
||||
var aeadChaCha20Poly1305 func(key, nonceMask []byte) cipher.AEAD
|
||||
|
||||
var cipherSuites = []*qtls.CipherSuiteTLS13{
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_256_GCM_SHA384,
|
||||
KeyLen: 32,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
KeyLen: 32,
|
||||
AEAD: nil, // will be set by init
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := forceexport.GetFunc(&aeadChaCha20Poly1305, "github.com/marten-seemann/qtls.aeadChaCha20Poly1305"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, s := range cipherSuites {
|
||||
if s.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 {
|
||||
s.AEAD = aeadChaCha20Poly1305
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
119
internal/handshake/header_protector.go
Normal file
119
internal/handshake/header_protector.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
|
||||
"github.com/marten-seemann/chacha20"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
|
||||
switch suite.ID {
|
||||
case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384:
|
||||
return newAESHeaderProtector(suite, trafficSecret, isLongHeader)
|
||||
case qtls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader)
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask []byte
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
mask: make([]byte, block.BlockSize()),
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask, sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
type chachaHeaderProtector struct {
|
||||
mask [5]byte
|
||||
|
||||
key [32]byte
|
||||
sampleBuf [16]byte
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &chachaHeaderProtector{}
|
||||
|
||||
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
|
||||
|
||||
p := &chachaHeaderProtector{
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
copy(p.key[:], hpKey)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) < len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p.mask[i] = 0
|
||||
}
|
||||
copy(p.sampleBuf[:], sample)
|
||||
chacha20.XORKeyStream(p.mask[:], p.mask[:], &p.sampleBuf, &p.key)
|
||||
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
@@ -10,6 +9,13 @@ import (
|
||||
|
||||
var quicVersion1Salt = []byte{0x7f, 0xbc, 0xdb, 0x0e, 0x7c, 0x66, 0xbb, 0xe9, 0x19, 0x3a, 0x96, 0xcd, 0x21, 0x51, 0x9e, 0xbd, 0x7a, 0x02, 0x64, 0x4a}
|
||||
|
||||
var initialSuite = &qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connID)
|
||||
@@ -21,20 +27,15 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
|
||||
mySecret = serverSecret
|
||||
otherSecret = clientSecret
|
||||
}
|
||||
myKey, myHPKey, myIV := computeInitialKeyAndIV(mySecret)
|
||||
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
|
||||
myKey, myIV := computeInitialKeyAndIV(mySecret)
|
||||
otherKey, otherIV := computeInitialKeyAndIV(otherSecret)
|
||||
|
||||
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
||||
hpEncrypter, err := aes.NewCipher(myHPKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
||||
hpDecrypter, err := aes.NewCipher(otherHPKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil
|
||||
|
||||
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true)),
|
||||
nil
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
@@ -44,9 +45,8 @@ func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []
|
||||
return
|
||||
}
|
||||
|
||||
func computeInitialKeyAndIV(secret []byte) (key, hpKey, iv []byte) {
|
||||
func computeInitialKeyAndIV(secret []byte) (key, iv []byte) {
|
||||
key = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
|
||||
hpKey = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16)
|
||||
iv = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -42,19 +42,17 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
|
||||
It("computes the client key and IV", func() {
|
||||
clientSecret, _ := computeSecrets(connID)
|
||||
Expect(clientSecret).To(Equal(split("8a3515a14ae3c31b9c2d6d5bc58538ca 5cd2baa119087143e60887428dcb52f6")))
|
||||
key, hpKey, iv := computeInitialKeyAndIV(clientSecret)
|
||||
key, iv := computeInitialKeyAndIV(clientSecret)
|
||||
Expect(key).To(Equal(split("98b0d7e5e7a402c67c33f350fa65ea54")))
|
||||
Expect(iv).To(Equal(split("19e94387805eb0b46c03a788")))
|
||||
Expect(hpKey).To(Equal(split("0edd982a6ac527f2eddcbb7348dea5d7")))
|
||||
})
|
||||
|
||||
It("computes the server key and IV", func() {
|
||||
_, serverSecret := computeSecrets(connID)
|
||||
Expect(serverSecret).To(Equal(split("47b2eaea6c266e32c0697a9e2a898bdf 5c4fb3e5ac34f0e549bf2c58581a3811")))
|
||||
key, hpKey, iv := computeInitialKeyAndIV(serverSecret)
|
||||
key, iv := computeInitialKeyAndIV(serverSecret)
|
||||
Expect(key).To(Equal(split("9a8be902a9bdd91d16064ca118045fb4")))
|
||||
Expect(iv).To(Equal(split("0a82086d32205ba22241d8dc")))
|
||||
Expect(hpKey).To(Equal(split("94b9452d2b3c7c7f6da7fdd8593537fd")))
|
||||
})
|
||||
|
||||
It("encrypts the client's Initial", func() {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
@@ -11,13 +9,6 @@ import (
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type cipherSuite interface {
|
||||
Hash() crypto.Hash
|
||||
KeyLen() int
|
||||
IVLen() int
|
||||
AEAD(key, nonce []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
@@ -140,3 +131,16 @@ func tlsConfigToQtlsConfig(
|
||||
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||
}
|
||||
}
|
||||
|
||||
func cipherSuiteName(id uint16) string {
|
||||
switch id {
|
||||
case qtls.TLS_AES_128_GCM_SHA256:
|
||||
return "TLS_AES_128_GCM_SHA256"
|
||||
case qtls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||
case qtls.TLS_AES_256_GCM_SHA384:
|
||||
return "TLS_AES_256_GCM_SHA384"
|
||||
default:
|
||||
return "unknown cipher suite"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func setKeyUpdateInterval() {
|
||||
}
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite cipherSuite
|
||||
suite *qtls.CipherSuiteTLS13
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
largestAcked protocol.PacketNumber
|
||||
@@ -67,8 +67,8 @@ type updatableAEAD struct {
|
||||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
hpDecrypter cipher.Block
|
||||
hpEncrypter cipher.Block
|
||||
headerDecrypter headerProtector
|
||||
headerEncrypter headerProtector
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
@@ -76,7 +76,6 @@ type updatableAEAD struct {
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ ShortHeaderOpener = &updatableAEAD{}
|
||||
@@ -104,8 +103,8 @@ func (a *updatableAEAD) rollKeys(now time.Time) {
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextSendTrafficSecret)
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
|
||||
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
|
||||
}
|
||||
@@ -116,33 +115,31 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
|
||||
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret)
|
||||
a.hpDecrypter = createHeaderProtector(suite, trafficSecret)
|
||||
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
|
||||
a.aeadOverhead = a.rcvAEAD.Overhead()
|
||||
a.suite = suite
|
||||
}
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
|
||||
}
|
||||
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret)
|
||||
a.hpEncrypter = createHeaderProtector(suite, trafficSecret)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
|
||||
a.aeadOverhead = a.sendAEAD.Overhead()
|
||||
a.suite = suite
|
||||
}
|
||||
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
|
||||
}
|
||||
|
||||
@@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int {
|
||||
return a.aeadOverhead
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != len(a.hpMask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
a.hpEncrypter.Encrypt(a.hpMask, sample)
|
||||
*firstByte ^= a.hpMask[0] & 0x1f
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= a.hpMask[i+1]
|
||||
}
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != len(a.hpMask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
a.hpDecrypter.Encrypt(a.hpMask, sample)
|
||||
*firstByte ^= a.hpMask[0] & 0x1f
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= a.hpMask[i+1]
|
||||
}
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -15,252 +13,244 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockCipherSuite struct{}
|
||||
|
||||
var _ cipherSuite = &mockCipherSuite{}
|
||||
|
||||
func (c *mockCipherSuite) Hash() crypto.Hash { return crypto.SHA256 }
|
||||
func (c *mockCipherSuite) KeyLen() int { return 16 }
|
||||
func (c *mockCipherSuite) IVLen() int { return 12 }
|
||||
func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return gcm
|
||||
}
|
||||
|
||||
var _ = Describe("Updatable AEAD", func() {
|
||||
getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) {
|
||||
trafficSecret1 := make([]byte, 16)
|
||||
trafficSecret2 := make([]byte, 16)
|
||||
rand.Read(trafficSecret1)
|
||||
rand.Read(trafficSecret2)
|
||||
for i := range cipherSuites {
|
||||
cs := cipherSuites[i]
|
||||
|
||||
client = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
server = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
client.SetReadKey(&mockCipherSuite{}, trafficSecret2)
|
||||
client.SetWriteKey(&mockCipherSuite{}, trafficSecret1)
|
||||
server.SetReadKey(&mockCipherSuite{}, trafficSecret1)
|
||||
server.SetWriteKey(&mockCipherSuite{}, trafficSecret2)
|
||||
return
|
||||
}
|
||||
Context(fmt.Sprintf("using %s", cipherSuiteName(cs.ID)), func() {
|
||||
|
||||
Context("header protection", func() {
|
||||
It("encrypts and decrypts the header", func() {
|
||||
server, client := getPeers(&congestion.RTTStats{})
|
||||
var lastFiveBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
client.EncryptHeader(sample, &header[0], header[9:13])
|
||||
if header[0]&0x1f != 0xb5&0x1f {
|
||||
lastFiveBitsDifferent++
|
||||
}
|
||||
Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
|
||||
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
server.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) {
|
||||
trafficSecret1 := make([]byte, 16)
|
||||
trafficSecret2 := make([]byte, 16)
|
||||
rand.Read(trafficSecret1)
|
||||
rand.Read(trafficSecret2)
|
||||
|
||||
client = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
server = newUpdatableAEAD(rttStats, utils.DefaultLogger)
|
||||
client.SetReadKey(cs, trafficSecret2)
|
||||
client.SetWriteKey(cs, trafficSecret1)
|
||||
server.SetReadKey(cs, trafficSecret1)
|
||||
server.SetWriteKey(cs, trafficSecret2)
|
||||
return
|
||||
}
|
||||
Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
|
||||
})
|
||||
})
|
||||
|
||||
Context("message encryption", func() {
|
||||
var msg, ad []byte
|
||||
var server, client *updatableAEAD
|
||||
var rttStats *congestion.RTTStats
|
||||
|
||||
BeforeEach(func() {
|
||||
rttStats = &congestion.RTTStats{}
|
||||
server, client = getPeers(rttStats)
|
||||
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad = []byte("Donec in velit neque.")
|
||||
})
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
Context("key updates", func() {
|
||||
Context("receiving key updates", func() {
|
||||
It("updates keys", func() {
|
||||
now := time.Now()
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
|
||||
server.rollKeys(now)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
|
||||
Expect(encrypted0).ToNot(Equal(encrypted1))
|
||||
// expect opening to fail. The client didn't roll keys yet
|
||||
_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
client.rollKeys(now)
|
||||
decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("updates the keys when receiving a packet with the next key phase", func() {
|
||||
now := time.Now()
|
||||
// receive the first packet at key phase zero
|
||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||
decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
// send one packet at key phase zero
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now received a message at key phase one
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x43, ad)
|
||||
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("opens a reordered packet with the old keys after an update", func() {
|
||||
now := time.Now()
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// send one packet at key phase zero
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("drops keys 3 PTOs after a key update", func() {
|
||||
now := time.Now()
|
||||
rttStats.UpdateRTT(10*time.Millisecond, 0, now)
|
||||
pto := rttStats.PTO()
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// send one packet at key phase zero
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrKeysDropped))
|
||||
})
|
||||
|
||||
It("errors when the peer starts with key phase 1", func() {
|
||||
client.rollKeys(time.Now())
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
|
||||
})
|
||||
|
||||
It("errors when the peer updates keys too frequently", func() {
|
||||
// receive the first packet at key phase zero
|
||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// now receive a packet at key phase one, before having sent any packets
|
||||
client.rollKeys(time.Now())
|
||||
encrypted1 := client.Seal(nil, msg, 0x42, ad)
|
||||
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
|
||||
Context("header protection", func() {
|
||||
It("encrypts and decrypts the header", func() {
|
||||
server, client := getPeers(&congestion.RTTStats{})
|
||||
var lastFiveBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
client.EncryptHeader(sample, &header[0], header[9:13])
|
||||
if header[0]&0x1f != 0xb5&0x1f {
|
||||
lastFiveBitsDifferent++
|
||||
}
|
||||
Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
|
||||
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
server.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
}
|
||||
Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
|
||||
})
|
||||
})
|
||||
|
||||
Context("initiating key updates", func() {
|
||||
const keyUpdateInterval = 20
|
||||
Context("message encryption", func() {
|
||||
var msg, ad []byte
|
||||
var server, client *updatableAEAD
|
||||
var rttStats *congestion.RTTStats
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
||||
server.keyUpdateInterval = keyUpdateInterval
|
||||
rttStats = &congestion.RTTStats{}
|
||||
server, client = getPeers(rttStats)
|
||||
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad = []byte("Donec in velit neque.")
|
||||
})
|
||||
|
||||
It("initiates a key update after sealing the maximum number of packets", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
}
|
||||
// no update allowed before receiving an acknowledgement for the current key phase
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.SetLargestAcked(0)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
It("encrypts and decrypts a message", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("initiates a key update after opening the maximum number of packets", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted := client.Seal(nil, msg, pn, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
// no update allowed before receiving an acknowledgement for the current key phase
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, 1, ad)
|
||||
server.SetLargestAcked(1)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
})
|
||||
|
||||
Context("reading the key update env", func() {
|
||||
AfterEach(func() {
|
||||
os.Setenv(keyUpdateEnv, "")
|
||||
setKeyUpdateInterval()
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
It("uses the default value if the env is not set", func() {
|
||||
setKeyUpdateInterval()
|
||||
Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
})
|
||||
|
||||
It("uses the env", func() {
|
||||
os.Setenv(keyUpdateEnv, "1337")
|
||||
setKeyUpdateInterval()
|
||||
Expect(keyUpdateInterval).To(BeEquivalentTo(1337))
|
||||
})
|
||||
Context("key updates", func() {
|
||||
Context("receiving key updates", func() {
|
||||
It("updates keys", func() {
|
||||
now := time.Now()
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
|
||||
server.rollKeys(now)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
|
||||
Expect(encrypted0).ToNot(Equal(encrypted1))
|
||||
// expect opening to fail. The client didn't roll keys yet
|
||||
_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrDecryptionFailed))
|
||||
client.rollKeys(now)
|
||||
decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("panics when it can't parse the env", func() {
|
||||
os.Setenv(keyUpdateEnv, "foobar")
|
||||
Expect(setKeyUpdateInterval).To(Panic())
|
||||
It("updates the keys when receiving a packet with the next key phase", func() {
|
||||
now := time.Now()
|
||||
// receive the first packet at key phase zero
|
||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||
decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
// send one packet at key phase zero
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now received a message at key phase one
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x43, ad)
|
||||
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("opens a reordered packet with the old keys after an update", func() {
|
||||
now := time.Now()
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// send one packet at key phase zero
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("drops keys 3 PTOs after a key update", func() {
|
||||
now := time.Now()
|
||||
rttStats.UpdateRTT(10*time.Millisecond, 0, now)
|
||||
pto := rttStats.PTO()
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// send one packet at key phase zero
|
||||
_ = server.Seal(nil, msg, 0x1, ad)
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys(now)
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError(ErrKeysDropped))
|
||||
})
|
||||
|
||||
It("errors when the peer starts with key phase 1", func() {
|
||||
client.rollKeys(time.Now())
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
|
||||
})
|
||||
|
||||
It("errors when the peer updates keys too frequently", func() {
|
||||
// receive the first packet at key phase zero
|
||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// now receive a packet at key phase one, before having sent any packets
|
||||
client.rollKeys(time.Now())
|
||||
encrypted1 := client.Seal(nil, msg, 0x42, ad)
|
||||
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("initiating key updates", func() {
|
||||
const keyUpdateInterval = 20
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
||||
server.keyUpdateInterval = keyUpdateInterval
|
||||
})
|
||||
|
||||
It("initiates a key update after sealing the maximum number of packets", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
}
|
||||
// no update allowed before receiving an acknowledgement for the current key phase
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.SetLargestAcked(0)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("initiates a key update after opening the maximum number of packets", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted := client.Seal(nil, msg, pn, ad)
|
||||
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
// no update allowed before receiving an acknowledgement for the current key phase
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, 1, ad)
|
||||
server.SetLargestAcked(1)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
})
|
||||
|
||||
Context("reading the key update env", func() {
|
||||
AfterEach(func() {
|
||||
os.Setenv(keyUpdateEnv, "")
|
||||
setKeyUpdateInterval()
|
||||
})
|
||||
|
||||
It("uses the default value if the env is not set", func() {
|
||||
setKeyUpdateInterval()
|
||||
Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
||||
})
|
||||
|
||||
It("uses the env", func() {
|
||||
os.Setenv(keyUpdateEnv, "1337")
|
||||
setKeyUpdateInterval()
|
||||
Expect(keyUpdateInterval).To(BeEquivalentTo(1337))
|
||||
})
|
||||
|
||||
It("panics when it can't parse the env", func() {
|
||||
os.Setenv(keyUpdateEnv, "foobar")
|
||||
Expect(setKeyUpdateInterval).To(Panic())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user