From 63c079e234f11f0174152dfbb00b30208a1f77c7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 6 Sep 2019 15:45:41 +0700 Subject: [PATCH 1/4] move header protection to a separate struct --- internal/handshake/aead.go | 57 ++++++----------------- internal/handshake/aead_test.go | 3 +- internal/handshake/crypto_setup.go | 4 +- internal/handshake/header_protector.go | 62 ++++++++++++++++++++++++++ internal/handshake/initial_aead.go | 8 ++-- internal/handshake/updatable_aead.go | 33 ++++---------- 6 files changed, 93 insertions(+), 74 deletions(-) create mode 100644 internal/handshake/header_protector.go diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index d5d71c9b..00043113 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -1,32 +1,28 @@ package handshake import ( - "crypto/aes" "crypto/cipher" "encoding/binary" - "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/marten-seemann/qtls" ) type sealer struct { - aead cipher.AEAD - hpEncrypter cipher.Block + aead cipher.AEAD + headerProtector headerProtector // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ LongHeaderSealer = &sealer{} -func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer { +func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { return &sealer{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - hpEncrypter: hpEncrypter, - hpMask: make([]byte, hpEncrypter.BlockSize()), + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), } } @@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by } func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != s.hpEncrypter.BlockSize() { - panic("invalid sample size") - } - s.hpEncrypter.Encrypt(s.hpMask, sample) - *firstByte ^= s.hpMask[0] & 0xf - for i := range pnBytes { - pnBytes[i] ^= s.hpMask[i+1] - } + s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) } func (s *sealer) Overhead() int { @@ -53,22 +42,20 @@ func (s *sealer) Overhead() int { } type longHeaderOpener struct { - aead cipher.AEAD - pnDecrypter cipher.Block + aead cipher.AEAD + headerProtector headerProtector // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ LongHeaderOpener = &longHeaderOpener{} -func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener { +func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { return &longHeaderOpener{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - pnDecrypter: pnDecrypter, - hpMask: make([]byte, pnDecrypter.BlockSize()), + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), } } @@ -84,14 +71,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad [] } func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != o.pnDecrypter.BlockSize() { - panic("invalid sample size") - } - o.pnDecrypter.Encrypt(o.hpMask, sample) - *firstByte ^= o.hpMask[0] & 0xf - for i := range pnBytes { - pnBytes[i] ^= o.hpMask[i+1] - } + o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) } func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { @@ -99,12 +79,3 @@ func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) return suite.AEAD(key, iv) } - -func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block { - hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) - hp, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - return hp -} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 7498fb5a..15391472 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -24,7 +24,8 @@ var _ = Describe("AEAD", func() { iv := make([]byte, 12) rand.Read(iv) - return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock) + return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)), + newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true)) } Context("message encryption", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index c0882381..85b6e1f1 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -498,7 +498,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph h.readEncLevel = protocol.EncryptionHandshake h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret), - createHeaderProtector(suite, trafficSecret), + newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), ) h.logger.Debugf("Installed Handshake Read keys") case qtls.EncryptionApplication: @@ -520,7 +520,7 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip h.writeEncLevel = protocol.EncryptionHandshake h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret), - createHeaderProtector(suite, trafficSecret), + newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), ) h.logger.Debugf("Installed Handshake Write keys") case qtls.EncryptionApplication: diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go new file mode 100644 index 00000000..b6a7b7ea --- /dev/null +++ b/internal/handshake/header_protector.go @@ -0,0 +1,62 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "fmt" + + "github.com/marten-seemann/qtls" +) + +type headerProtector interface { + EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) + DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) +} + +func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block { + hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) + hp, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return hp +} + +type aesHeaderProtector struct { + mask []byte + block cipher.Block + isLongHeader bool +} + +var _ headerProtector = &aesHeaderProtector{} + +func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector { + return &aesHeaderProtector{ + block: block, + mask: make([]byte, block.BlockSize()), + isLongHeader: isLongHeader, + } +} + +func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != len(p.mask) { + panic("invalid sample size") + } + p.block.Encrypt(p.mask, sample) + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index da794c5c..eb90122b 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -25,16 +25,18 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret) encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - hpEncrypter, err := aes.NewCipher(myHPKey) + encrypterBlock, err := aes.NewCipher(myHPKey) if err != nil { return nil, nil, err } decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - hpDecrypter, err := aes.NewCipher(otherHPKey) + decrypterBlock, err := aes.NewCipher(otherHPKey) if err != nil { return nil, nil, err } - return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil + return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)), + nil } func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index ff87aeeb..172c3e79 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -67,8 +67,8 @@ type updatableAEAD struct { nextRcvTrafficSecret []byte nextSendTrafficSecret []byte - hpDecrypter cipher.Block - hpEncrypter cipher.Block + headerDecrypter headerProtector + headerEncrypter headerProtector rttStats *congestion.RTTStats @@ -76,7 +76,6 @@ type updatableAEAD struct { // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ ShortHeaderOpener = &updatableAEAD{} @@ -118,10 +117,9 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { a.rcvAEAD = createAEAD(suite, trafficSecret) - a.hpDecrypter = createHeaderProtector(suite, trafficSecret) + a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) if a.suite == nil { a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) - a.hpMask = make([]byte, a.hpDecrypter.BlockSize()) a.aeadOverhead = a.rcvAEAD.Overhead() a.suite = suite } @@ -134,10 +132,9 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret) - a.hpEncrypter = createHeaderProtector(suite, trafficSecret) + a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) if a.suite == nil { a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) - a.hpMask = make([]byte, a.hpEncrypter.BlockSize()) a.aeadOverhead = a.sendAEAD.Overhead() a.suite = suite } @@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int { return a.aeadOverhead } -func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != len(a.hpMask) { - panic("invalid sample size") - } - a.hpEncrypter.Encrypt(a.hpMask, sample) - *firstByte ^= a.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= a.hpMask[i+1] - } +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) } -func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != len(a.hpMask) { - panic("invalid sample size") - } - a.hpDecrypter.Encrypt(a.hpMask, sample) - *firstByte ^= a.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= a.hpMask[i+1] - } +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) } From 33b74fca79cea3befbc62be48fcbe45dfceb2b9b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 6 Sep 2019 16:25:04 +0700 Subject: [PATCH 2/4] use the new qtls SetWriteKey and SetReadKey interface --- go.mod | 6 ++--- go.sum | 16 ++++++++----- internal/handshake/aead.go | 6 ++--- internal/handshake/aead_test.go | 8 ++----- internal/handshake/crypto_setup.go | 8 +++---- internal/handshake/handshake_suite_test.go | 13 +++++++++- internal/handshake/header_protector.go | 23 ++++++++++++------ internal/handshake/initial_aead.go | 28 ++++++++++------------ internal/handshake/initial_aead_test.go | 6 ++--- internal/handshake/qtls.go | 9 ------- internal/handshake/updatable_aead.go | 18 +++++++------- internal/handshake/updatable_aead_test.go | 26 ++++---------------- 12 files changed, 78 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 3d6f0545..32a1df7f 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ require ( github.com/golang/mock v1.2.0 github.com/golang/protobuf v1.3.0 github.com/marten-seemann/qpack v0.1.0 - github.com/marten-seemann/qtls v0.3.3 + github.com/marten-seemann/qtls v0.4.0 github.com/onsi/ginkgo v1.7.0 github.com/onsi/gomega v1.4.3 - golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 - golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7 + golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 ) diff --git a/go.sum b/go.sum index cfbfeb74..c534a87a 100644 --- a/go.sum +++ b/go.sum @@ -11,24 +11,28 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= -github.com/marten-seemann/qtls v0.3.3 h1:s6E9lHmjzoOqGnEw+7F+RREKEPq4lchp1Sl+Rj5Hqsc= -github.com/marten-seemann/qtls v0.3.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk= +github.com/marten-seemann/qtls v0.4.0 h1:HM9ftULNeuhGiCliIfPKvp5VDJw6pvi/Ghq6PYf7B0E= +github.com/marten-seemann/qtls v0.4.0/go.mod h1:pxVXcHHw1pNIt8Qo0pwSYQEoZ8yYOOPXTCZLQQunvRc= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 h1:jsG6UpNLt9iAsb0S2AGW28DveNzzgmbXR+ENoPjUeIU= -golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM= +golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7 h1:Qe/u+eY379X4He4GBMFZYu3pmh1ML5yT1aL1ndNM1zQ= golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e h1:ZytStCyV048ZqDsWHiYDdoI2Vd4msMcrDECFxS+tL9c= -golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd h1:DBH9mDw0zluJT/R+nGuV3jWFWLFaHyYZWD4tOT+cjn0= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 00043113..e0ef7041 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -74,8 +74,8 @@ func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) } -func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { - key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen()) - iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) +func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD { + key := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic key", suite.KeyLen) + iv := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic iv", suite.IVLen()) return suite.AEAD(key, iv) } diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 15391472..407da319 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -19,13 +19,9 @@ var _ = Describe("AEAD", func() { Expect(err).ToNot(HaveOccurred()) aead, err := cipher.NewGCM(block) Expect(err).ToNot(HaveOccurred()) - hpBlock, err := aes.NewCipher(hpKey) - Expect(err).ToNot(HaveOccurred()) - iv := make([]byte, 12) - rand.Read(iv) - return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)), - newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true)) + return newLongHeaderSealer(aead, newHeaderProtector(aesSuite, key, true)), + newLongHeaderOpener(aead, newAESHeaderProtector(aesSuite, key, true)) } Context("message encryption", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 85b6e1f1..c67108e2 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -491,14 +491,14 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { return msg, nil } -func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuite, trafficSecret []byte) { +func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { h.mutex.Lock() switch encLevel { case qtls.EncryptionHandshake: h.readEncLevel = protocol.EncryptionHandshake h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret), - newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), + newHeaderProtector(suite, trafficSecret, true), ) h.logger.Debugf("Installed Handshake Read keys") case qtls.EncryptionApplication: @@ -513,14 +513,14 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph h.receivedReadKey <- struct{}{} } -func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuite, trafficSecret []byte) { +func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { h.mutex.Lock() switch encLevel { case qtls.EncryptionHandshake: h.writeEncLevel = protocol.EncryptionHandshake h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret), - newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), + newHeaderProtector(suite, trafficSecret, true), ) h.logger.Debugf("Installed Handshake Write keys") case qtls.EncryptionApplication: diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 91f3d14b..10645e50 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -1,14 +1,18 @@ package handshake import ( + "crypto" + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qtls" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "testing" ) -func TestQuicGo(t *testing.T) { +func TestHandshake(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Handshake Suite") } @@ -22,3 +26,10 @@ var _ = BeforeEach(func() { var _ = AfterEach(func() { mockCtrl.Finish() }) + +var aesSuite = &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, +} diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index b6a7b7ea..17e8fcd4 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -13,13 +13,17 @@ type headerProtector interface { DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) } -func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block { - hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) - hp, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) +func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { + switch suite.ID { + case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384: + return newAESHeaderProtector(suite, trafficSecret, isLongHeader) + case qtls.TLS_CHACHA20_POLY1305_SHA256: + // TODO: implement ChaCha header protection + fallthrough + default: + panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } - return hp + } type aesHeaderProtector struct { @@ -30,7 +34,12 @@ type aesHeaderProtector struct { var _ headerProtector = &aesHeaderProtector{} -func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector { +func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { + hpKey := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen) + block, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } return &aesHeaderProtector{ block: block, mask: make([]byte, block.BlockSize()), diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index eb90122b..d29b48af 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -2,7 +2,6 @@ package handshake import ( "crypto" - "crypto/aes" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/marten-seemann/qtls" @@ -10,6 +9,13 @@ import ( var quicVersion1Salt = []byte{0x7f, 0xbc, 0xdb, 0x0e, 0x7c, 0x66, 0xbb, 0xe9, 0x19, 0x3a, 0x96, 0xcd, 0x21, 0x51, 0x9e, 0xbd, 0x7a, 0x02, 0x64, 0x4a} +var initialSuite = &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, +} + // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) { clientSecret, serverSecret := computeSecrets(connID) @@ -21,21 +27,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo mySecret = serverSecret otherSecret = clientSecret } - myKey, myHPKey, myIV := computeInitialKeyAndIV(mySecret) - otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret) + myKey, myIV := computeInitialKeyAndIV(mySecret) + otherKey, otherIV := computeInitialKeyAndIV(otherSecret) encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - encrypterBlock, err := aes.NewCipher(myHPKey) - if err != nil { - return nil, nil, err - } decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - decrypterBlock, err := aes.NewCipher(otherHPKey) - if err != nil { - return nil, nil, err - } - return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)), - newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)), + + return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true)), nil } @@ -46,9 +45,8 @@ func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret [] return } -func computeInitialKeyAndIV(secret []byte) (key, hpKey, iv []byte) { +func computeInitialKeyAndIV(secret []byte) (key, iv []byte) { key = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) - hpKey = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16) iv = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12) return } diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index 0cac773b..e978b860 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -42,19 +42,17 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { It("computes the client key and IV", func() { clientSecret, _ := computeSecrets(connID) Expect(clientSecret).To(Equal(split("8a3515a14ae3c31b9c2d6d5bc58538ca 5cd2baa119087143e60887428dcb52f6"))) - key, hpKey, iv := computeInitialKeyAndIV(clientSecret) + key, iv := computeInitialKeyAndIV(clientSecret) Expect(key).To(Equal(split("98b0d7e5e7a402c67c33f350fa65ea54"))) Expect(iv).To(Equal(split("19e94387805eb0b46c03a788"))) - Expect(hpKey).To(Equal(split("0edd982a6ac527f2eddcbb7348dea5d7"))) }) It("computes the server key and IV", func() { _, serverSecret := computeSecrets(connID) Expect(serverSecret).To(Equal(split("47b2eaea6c266e32c0697a9e2a898bdf 5c4fb3e5ac34f0e549bf2c58581a3811"))) - key, hpKey, iv := computeInitialKeyAndIV(serverSecret) + key, iv := computeInitialKeyAndIV(serverSecret) Expect(key).To(Equal(split("9a8be902a9bdd91d16064ca118045fb4"))) Expect(iv).To(Equal(split("0a82086d32205ba22241d8dc"))) - Expect(hpKey).To(Equal(split("94b9452d2b3c7c7f6da7fdd8593537fd"))) }) It("encrypts the client's Initial", func() { diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 52d0723f..cd093142 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -1,8 +1,6 @@ package handshake import ( - "crypto" - "crypto/cipher" "crypto/tls" "net" "time" @@ -11,13 +9,6 @@ import ( "github.com/marten-seemann/qtls" ) -type cipherSuite interface { - Hash() crypto.Hash - KeyLen() int - IVLen() int - AEAD(key, nonce []byte) cipher.AEAD -} - type conn struct { remoteAddr net.Addr } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 172c3e79..2444bc50 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -43,7 +43,7 @@ func setKeyUpdateInterval() { } type updatableAEAD struct { - suite cipherSuite + suite *qtls.CipherSuiteTLS13 keyPhase protocol.KeyPhase largestAcked protocol.PacketNumber @@ -103,8 +103,8 @@ func (a *updatableAEAD) rollKeys(now time.Time) { a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD - a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextRcvTrafficSecret) - a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextSendTrafficSecret) + a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret) a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret) } @@ -115,31 +115,31 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. -func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { +func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { a.rcvAEAD = createAEAD(suite, trafficSecret) - a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) + a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false) if a.suite == nil { a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) a.aeadOverhead = a.rcvAEAD.Overhead() a.suite = suite } - a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret) + a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret) } // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. -func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { +func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret) - a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) + a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false) if a.suite == nil { a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) a.aeadOverhead = a.sendAEAD.Overhead() a.suite = suite } - a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret) } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 44df1295..80c1721f 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -1,9 +1,6 @@ package handshake import ( - "crypto" - "crypto/aes" - "crypto/cipher" "crypto/rand" "os" "time" @@ -15,21 +12,6 @@ import ( . "github.com/onsi/gomega" ) -type mockCipherSuite struct{} - -var _ cipherSuite = &mockCipherSuite{} - -func (c *mockCipherSuite) Hash() crypto.Hash { return crypto.SHA256 } -func (c *mockCipherSuite) KeyLen() int { return 16 } -func (c *mockCipherSuite) IVLen() int { return 12 } -func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - gcm, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - return gcm -} - var _ = Describe("Updatable AEAD", func() { getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { trafficSecret1 := make([]byte, 16) @@ -39,10 +21,10 @@ var _ = Describe("Updatable AEAD", func() { client = newUpdatableAEAD(rttStats, utils.DefaultLogger) server = newUpdatableAEAD(rttStats, utils.DefaultLogger) - client.SetReadKey(&mockCipherSuite{}, trafficSecret2) - client.SetWriteKey(&mockCipherSuite{}, trafficSecret1) - server.SetReadKey(&mockCipherSuite{}, trafficSecret1) - server.SetWriteKey(&mockCipherSuite{}, trafficSecret2) + client.SetReadKey(aesSuite, trafficSecret2) + client.SetWriteKey(aesSuite, trafficSecret1) + server.SetReadKey(aesSuite, trafficSecret1) + server.SetWriteKey(aesSuite, trafficSecret2) return } From fa89ec345ab9bca8ea6e06c2150c7fc7c17b9562 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 7 Sep 2019 11:06:30 +0700 Subject: [PATCH 3/4] add support for ChaCha20 header protection --- go.mod | 2 + go.sum | 4 + internal/handshake/aead_test.go | 143 +++---- internal/handshake/handshake_suite_test.go | 51 ++- internal/handshake/header_protector.go | 52 ++- internal/handshake/updatable_aead_test.go | 431 +++++++++++---------- 6 files changed, 398 insertions(+), 285 deletions(-) diff --git a/go.mod b/go.mod index 32a1df7f..41eabadf 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,11 @@ module github.com/lucas-clemente/quic-go go 1.13 require ( + github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75 github.com/cheekybits/genny v1.0.0 github.com/golang/mock v1.2.0 github.com/golang/protobuf v1.3.0 + github.com/marten-seemann/chacha20 v0.2.0 github.com/marten-seemann/qpack v0.1.0 github.com/marten-seemann/qtls v0.4.0 github.com/onsi/ginkgo v1.7.0 diff --git a/go.sum b/go.sum index c534a87a..c343aaa8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75 h1:3ILjVyslFbc4jl1w5TWuvvslFD/nDfR2H8tVaMVLrEY= +github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75/go.mod h1:uAXEEpARkRhCZfEvy/y0Jcc888f9tHCc1W7/UeEtreE= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= @@ -9,6 +11,8 @@ github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/marten-seemann/chacha20 v0.2.0 h1:f40vqzzx+3GdOmzQoItkLX5WLvHgPgyYqFFIO5Gh4hQ= +github.com/marten-seemann/chacha20 v0.2.0/go.mod h1:HSdjFau7GzYRj+ahFNwsO3ouVJr1HFkWoEwNDb4TMtE= github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= github.com/marten-seemann/qtls v0.4.0 h1:HM9ftULNeuhGiCliIfPKvp5VDJw6pvi/Ghq6PYf7B0E= diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 407da319..c0e5f2a2 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -4,83 +4,92 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "fmt" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("AEAD", func() { - getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) + for i := range cipherSuites { + cs := cipherSuites[i] - return newLongHeaderSealer(aead, newHeaderProtector(aesSuite, key, true)), - newLongHeaderOpener(aead, newAESHeaderProtector(aesSuite, key, true)) - } + Context(fmt.Sprintf("using %s", cs.name), func() { + suite := cs.suite - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) - It("encrypts and decrypts a message", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - }) - - Context("header encryption", func() { - It("encrypts and encrypts the header", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + return newLongHeaderSealer(aead, newHeaderProtector(suite, key, true)), + newLongHeaderOpener(aead, newHeaderProtector(suite, key, true)) } - Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) - }) - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener() - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - sealer.EncryptHeader(sample, &header[0], header[9:13]) - rand.Read(sample) // use a different sample - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + }) + + Context("header encryption", func() { + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) + + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener() + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) + }) }) - }) + } }) diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 10645e50..158460eb 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -2,7 +2,9 @@ package handshake import ( "crypto" + "crypto/cipher" + "github.com/alangpierce/go-forceexport" "github.com/golang/mock/gomock" "github.com/marten-seemann/qtls" @@ -27,9 +29,48 @@ var _ = AfterEach(func() { mockCtrl.Finish() }) -var aesSuite = &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, +var aeadChaCha20Poly1305 func(key, nonceMask []byte) cipher.AEAD + +var cipherSuites = []struct { + name string + suite *qtls.CipherSuiteTLS13 +}{ + { + name: "TLS_AES_128_GCM_SHA256", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, + }, + }, + { + name: "TLS_AES_256_GCM_SHA384", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_256_GCM_SHA384, + KeyLen: 32, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA384, + }, + }, + { + name: "TLS_CHACHA20_POLY1305_SHA256", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_CHACHA20_POLY1305_SHA256, + KeyLen: 32, + AEAD: nil, // will be set by init + Hash: crypto.SHA256, + }, + }, +} + +func init() { + if err := forceexport.GetFunc(&aeadChaCha20Poly1305, "github.com/marten-seemann/qtls.aeadChaCha20Poly1305"); err != nil { + panic(err) + } + for _, s := range cipherSuites { + if s.suite.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 { + s.suite.AEAD = aeadChaCha20Poly1305 + } + } } diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index 17e8fcd4..019d5703 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "fmt" + "github.com/marten-seemann/chacha20" "github.com/marten-seemann/qtls" ) @@ -18,8 +19,7 @@ func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLo case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384: return newAESHeaderProtector(suite, trafficSecret, isLongHeader) case qtls.TLS_CHACHA20_POLY1305_SHA256: - // TODO: implement ChaCha header protection - fallthrough + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader) default: panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } @@ -69,3 +69,51 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by hdrBytes[i] ^= p.mask[i+1] } } + +type chachaHeaderProtector struct { + mask [5]byte + + key [32]byte + sampleBuf [16]byte + isLongHeader bool +} + +var _ headerProtector = &chachaHeaderProtector{} + +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { + hpKey := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen) + + p := &chachaHeaderProtector{ + isLongHeader: isLongHeader, + } + copy(p.key[:], hpKey) + return p +} + +func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) < len(p.mask) { + panic("invalid sample size") + } + for i := 0; i < 5; i++ { + p.mask[i] = 0 + } + copy(p.sampleBuf[:], sample) + chacha20.XORKeyStream(p.mask[:], p.mask[:], &p.sampleBuf, &p.key) + + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 80c1721f..ab28d28b 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -2,6 +2,7 @@ package handshake import ( "crypto/rand" + "fmt" "os" "time" @@ -13,236 +14,244 @@ import ( ) var _ = Describe("Updatable AEAD", func() { - getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { - trafficSecret1 := make([]byte, 16) - trafficSecret2 := make([]byte, 16) - rand.Read(trafficSecret1) - rand.Read(trafficSecret2) + for i := range cipherSuites { + cs := cipherSuites[i] - client = newUpdatableAEAD(rttStats, utils.DefaultLogger) - server = newUpdatableAEAD(rttStats, utils.DefaultLogger) - client.SetReadKey(aesSuite, trafficSecret2) - client.SetWriteKey(aesSuite, trafficSecret1) - server.SetReadKey(aesSuite, trafficSecret1) - server.SetWriteKey(aesSuite, trafficSecret2) - return - } + Context(fmt.Sprintf("using %s", cs.name), func() { + suite := cs.suite - Context("header protection", func() { - It("encrypts and decrypts the header", func() { - server, client := getPeers(&congestion.RTTStats{}) - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - client.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - server.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + client = newUpdatableAEAD(rttStats, utils.DefaultLogger) + server = newUpdatableAEAD(rttStats, utils.DefaultLogger) + client.SetReadKey(suite, trafficSecret2) + client.SetWriteKey(suite, trafficSecret1) + server.SetReadKey(suite, trafficSecret1) + server.SetWriteKey(suite, trafficSecret2) + return } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - }) - Context("message encryption", func() { - var msg, ad []byte - var server, client *updatableAEAD - var rttStats *congestion.RTTStats - - BeforeEach(func() { - rttStats = &congestion.RTTStats{} - server, client = getPeers(rttStats) - msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad = []byte("Donec in velit neque.") - }) - - It("encrypts and decrypts a message", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - Context("key updates", func() { - Context("receiving key updates", func() { - It("updates keys", func() { - now := time.Now() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys(now) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys(now) - decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - }) - - It("updates the keys when receiving a packet with the next key phase", func() { - now := time.Now() - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x43, ad) - decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("opens a reordered packet with the old keys after an update", func() { - now := time.Now() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("errors when the peer starts with key phase 1", func() { - client.rollKeys(time.Now()) - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) - }) - - It("errors when the peer updates keys too frequently", func() { - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase one, before having sent any packets - client.rollKeys(time.Now()) - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + Context("header protection", func() { + It("encrypts and decrypts the header", func() { + server, client := getPeers(&congestion.RTTStats{}) + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + client.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + server.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) }) }) - Context("initiating key updates", func() { - const keyUpdateInterval = 20 + Context("message encryption", func() { + var msg, ad []byte + var server, client *updatableAEAD + var rttStats *congestion.RTTStats BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval + rttStats = &congestion.RTTStats{} + server, client = getPeers(rttStats) + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") }) - It("initiates a key update after sealing the maximum number of packets", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.SetLargestAcked(0) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) }) - It("initiates a key update after opening the maximum number of packets", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, 1, ad) - server.SetLargestAcked(1) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - }) - - Context("reading the key update env", func() { - AfterEach(func() { - os.Setenv(keyUpdateEnv, "") - setKeyUpdateInterval() + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("uses the default value if the env is not set", func() { - setKeyUpdateInterval() - Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("uses the env", func() { - os.Setenv(keyUpdateEnv, "1337") - setKeyUpdateInterval() - Expect(keyUpdateInterval).To(BeEquivalentTo(1337)) - }) + Context("key updates", func() { + Context("receiving key updates", func() { + It("updates keys", func() { + now := time.Now() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys(now) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys(now) + decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) - It("panics when it can't parse the env", func() { - os.Setenv(keyUpdateEnv, "foobar") - Expect(setKeyUpdateInterval).To(Panic()) + It("updates the keys when receiving a packet with the next key phase", func() { + now := time.Now() + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x43, ad) + decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + now := time.Now() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + rttStats.UpdateRTT(10*time.Millisecond, 0, now) + pto := rttStats.PTO() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("errors when the peer starts with key phase 1", func() { + client.rollKeys(time.Now()) + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) + }) + + It("errors when the peer updates keys too frequently", func() { + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase one, before having sent any packets + client.rollKeys(time.Now()) + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + }) + }) + + Context("initiating key updates", func() { + const keyUpdateInterval = 20 + + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + }) + + It("initiates a key update after sealing the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.SetLargestAcked(0) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after opening the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, 1, ad) + server.SetLargestAcked(1) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + }) + + Context("reading the key update env", func() { + AfterEach(func() { + os.Setenv(keyUpdateEnv, "") + setKeyUpdateInterval() + }) + + It("uses the default value if the env is not set", func() { + setKeyUpdateInterval() + Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + }) + + It("uses the env", func() { + os.Setenv(keyUpdateEnv, "1337") + setKeyUpdateInterval() + Expect(keyUpdateInterval).To(BeEquivalentTo(1337)) + }) + + It("panics when it can't parse the env", func() { + os.Setenv(keyUpdateEnv, "foobar") + Expect(setKeyUpdateInterval).To(Panic()) + }) + }) }) }) }) - }) + } }) From de3e1a3de5d36e8ce038b0c8f325d52ffedf9933 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 7 Sep 2019 11:47:21 +0700 Subject: [PATCH 4/4] log the cipher suite --- internal/handshake/aead_test.go | 8 ++-- internal/handshake/crypto_setup.go | 8 ++-- internal/handshake/handshake_suite_test.go | 48 ++++++++-------------- internal/handshake/qtls.go | 13 ++++++ internal/handshake/updatable_aead_test.go | 11 +++-- 5 files changed, 43 insertions(+), 45 deletions(-) diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index c0e5f2a2..49082d2f 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -14,9 +14,7 @@ var _ = Describe("AEAD", func() { for i := range cipherSuites { cs := cipherSuites[i] - Context(fmt.Sprintf("using %s", cs.name), func() { - suite := cs.suite - + Context(fmt.Sprintf("using %s", cipherSuiteName(cs.ID)), func() { getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { key := make([]byte, 16) hpKey := make([]byte, 16) @@ -27,8 +25,8 @@ var _ = Describe("AEAD", func() { aead, err := cipher.NewGCM(block) Expect(err).ToNot(HaveOccurred()) - return newLongHeaderSealer(aead, newHeaderProtector(suite, key, true)), - newLongHeaderOpener(aead, newHeaderProtector(suite, key, true)) + return newLongHeaderSealer(aead, newHeaderProtector(cs, key, true)), + newLongHeaderOpener(aead, newHeaderProtector(cs, key, true)) } Context("message encryption", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index c67108e2..91f555a3 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -500,12 +500,12 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph createAEAD(suite, trafficSecret), newHeaderProtector(suite, trafficSecret, true), ) - h.logger.Debugf("Installed Handshake Read keys") + h.logger.Debugf("Installed Handshake Read keys (using %s)", cipherSuiteName(suite.ID)) case qtls.EncryptionApplication: h.readEncLevel = protocol.Encryption1RTT h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true - h.logger.Debugf("Installed 1-RTT Read keys") + h.logger.Debugf("Installed 1-RTT Read keys (using %s)", cipherSuiteName(suite.ID)) default: panic("unexpected read encryption level") } @@ -522,12 +522,12 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip createAEAD(suite, trafficSecret), newHeaderProtector(suite, trafficSecret, true), ) - h.logger.Debugf("Installed Handshake Write keys") + h.logger.Debugf("Installed Handshake Write keys (using %s)", cipherSuiteName(suite.ID)) case qtls.EncryptionApplication: h.writeEncLevel = protocol.Encryption1RTT h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true - h.logger.Debugf("Installed 1-RTT Write keys") + h.logger.Debugf("Installed 1-RTT Write keys (using %s)", cipherSuiteName(suite.ID)) default: panic("unexpected write encryption level") } diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 158460eb..20750b99 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -31,36 +31,24 @@ var _ = AfterEach(func() { var aeadChaCha20Poly1305 func(key, nonceMask []byte) cipher.AEAD -var cipherSuites = []struct { - name string - suite *qtls.CipherSuiteTLS13 -}{ - { - name: "TLS_AES_128_GCM_SHA256", - suite: &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, - }, +var cipherSuites = []*qtls.CipherSuiteTLS13{ + &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, }, - { - name: "TLS_AES_256_GCM_SHA384", - suite: &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_256_GCM_SHA384, - KeyLen: 32, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA384, - }, + &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_256_GCM_SHA384, + KeyLen: 32, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA384, }, - { - name: "TLS_CHACHA20_POLY1305_SHA256", - suite: &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_CHACHA20_POLY1305_SHA256, - KeyLen: 32, - AEAD: nil, // will be set by init - Hash: crypto.SHA256, - }, + &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_CHACHA20_POLY1305_SHA256, + KeyLen: 32, + AEAD: nil, // will be set by init + Hash: crypto.SHA256, }, } @@ -69,8 +57,8 @@ func init() { panic(err) } for _, s := range cipherSuites { - if s.suite.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 { - s.suite.AEAD = aeadChaCha20Poly1305 + if s.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 { + s.AEAD = aeadChaCha20Poly1305 } } } diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index cd093142..9ce5e655 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -131,3 +131,16 @@ func tlsConfigToQtlsConfig( ReceivedExtensions: extHandler.ReceivedExtensions, } } + +func cipherSuiteName(id uint16) string { + switch id { + case qtls.TLS_AES_128_GCM_SHA256: + return "TLS_AES_128_GCM_SHA256" + case qtls.TLS_CHACHA20_POLY1305_SHA256: + return "TLS_CHACHA20_POLY1305_SHA256" + case qtls.TLS_AES_256_GCM_SHA384: + return "TLS_AES_256_GCM_SHA384" + default: + return "unknown cipher suite" + } +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index ab28d28b..51000a44 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -17,8 +17,7 @@ var _ = Describe("Updatable AEAD", func() { for i := range cipherSuites { cs := cipherSuites[i] - Context(fmt.Sprintf("using %s", cs.name), func() { - suite := cs.suite + Context(fmt.Sprintf("using %s", cipherSuiteName(cs.ID)), func() { getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { trafficSecret1 := make([]byte, 16) @@ -28,10 +27,10 @@ var _ = Describe("Updatable AEAD", func() { client = newUpdatableAEAD(rttStats, utils.DefaultLogger) server = newUpdatableAEAD(rttStats, utils.DefaultLogger) - client.SetReadKey(suite, trafficSecret2) - client.SetWriteKey(suite, trafficSecret1) - server.SetReadKey(suite, trafficSecret1) - server.SetWriteKey(suite, trafficSecret2) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) return }