From e74ede678f94518e077170d8e71d69936cf6d27b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 11 Jun 2019 11:43:35 +0800 Subject: [PATCH] move opening / sealing to the updatable AEAD --- internal/handshake/updatable_aead.go | 83 ++++++++++++++-------------- 1 file changed, 40 insertions(+), 43 deletions(-) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 6beb5b9df..3cf029977 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -7,53 +7,16 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) -type shortHeaderOpener struct { - aead cipher.AEAD - - // use a single slice to avoid allocations - nonceBuf []byte -} - -func newShortHeaderOpener(aead cipher.AEAD) *shortHeaderOpener { - return &shortHeaderOpener{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - } -} - -func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return o.aead.Open(dst, o.nonceBuf, src, ad) -} - -type shortHeaderSealer struct { - sealer -} - -func newShortHeaderSealer(aead cipher.AEAD) *shortHeaderSealer { - return &shortHeaderSealer{ - sealer: sealer{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - }, - } -} - -func (s *shortHeaderSealer) KeyPhase() protocol.KeyPhase { - return protocol.KeyPhaseOne -} - type updatableAEAD struct { - *shortHeaderSealer - *shortHeaderOpener + openAEAD cipher.AEAD + sealAEAD cipher.AEAD hpDecrypter cipher.Block hpEncrypter cipher.Block // use a single slice to avoid allocations - hpMask []byte + nonceBuf []byte + hpMask []byte } var _ ShortHeaderOpener = &updatableAEAD{} @@ -65,7 +28,13 @@ func newUpdatableAEAD() *updatableAEAD { func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { aead, hpDecrypter := createAEAD(suite, trafficSecret) - a.shortHeaderOpener = newShortHeaderOpener(aead) + if len(a.nonceBuf) == 0 { + a.nonceBuf = make([]byte, aead.NonceSize()) + } else if len(a.nonceBuf) != aead.NonceSize() { + panic("invalid nonce size") + } + a.openAEAD = aead + if len(a.hpMask) == 0 { a.hpMask = make([]byte, hpDecrypter.BlockSize()) } else if len(a.hpMask) != hpDecrypter.BlockSize() { @@ -76,7 +45,13 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { aead, hpEncrypter := createAEAD(suite, trafficSecret) - a.shortHeaderSealer = newShortHeaderSealer(aead) + if len(a.nonceBuf) == 0 { + a.nonceBuf = make([]byte, aead.NonceSize()) + } else if len(a.nonceBuf) != aead.NonceSize() { + panic("invalid nonce size") + } + a.sealAEAD = aead + if len(a.hpMask) == 0 { a.hpMask = make([]byte, hpEncrypter.BlockSize()) } else if len(a.hpMask) != hpEncrypter.BlockSize() { @@ -85,6 +60,28 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { a.hpEncrypter = hpEncrypter } +func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return a.openAEAD.Open(dst, a.nonceBuf, src, ad) +} + +func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return a.sealAEAD.Seal(dst, a.nonceBuf, src, ad) +} + +func (a *updatableAEAD) KeyPhase() protocol.KeyPhase { + return protocol.KeyPhaseOne +} + +func (a *updatableAEAD) Overhead() int { + return a.sealAEAD.Overhead() +} + func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { if len(sample) != a.hpEncrypter.BlockSize() { panic("invalid sample size")