diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 1943da0cb..6beb5b9df 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -8,22 +8,16 @@ import ( ) type shortHeaderOpener struct { - aead cipher.AEAD - hpDecrypter cipher.Block + aead cipher.AEAD // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } -var _ ShortHeaderOpener = &shortHeaderOpener{} - -func newShortHeaderOpener(aead cipher.AEAD, hpDecrypter cipher.Block) ShortHeaderOpener { +func newShortHeaderOpener(aead cipher.AEAD) *shortHeaderOpener { return &shortHeaderOpener{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - hpDecrypter: hpDecrypter, - hpMask: make([]byte, hpDecrypter.BlockSize()), + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), } } @@ -34,60 +28,81 @@ func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ pr return o.aead.Open(dst, o.nonceBuf, src, ad) } -func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != o.hpDecrypter.BlockSize() { - panic("invalid sample size") - } - o.hpDecrypter.Encrypt(o.hpMask, sample) - *firstByte ^= o.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= o.hpMask[i+1] - } -} - type shortHeaderSealer struct { sealer } -func newShortHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) ShortHeaderSealer { +func newShortHeaderSealer(aead cipher.AEAD) *shortHeaderSealer { return &shortHeaderSealer{ sealer: sealer{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - hpEncrypter: hpEncrypter, - hpMask: make([]byte, hpEncrypter.BlockSize()), + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), }, } } -func (s *shortHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != s.hpEncrypter.BlockSize() { - panic("invalid sample size") - } - s.hpEncrypter.Encrypt(s.hpMask, sample) - *firstByte ^= s.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= s.hpMask[i+1] - } -} - func (s *shortHeaderSealer) KeyPhase() protocol.KeyPhase { return protocol.KeyPhaseOne } type updatableAEAD struct { - ShortHeaderOpener - ShortHeaderSealer + *shortHeaderSealer + *shortHeaderOpener + + hpDecrypter cipher.Block + hpEncrypter cipher.Block + + // use a single slice to avoid allocations + hpMask []byte } +var _ ShortHeaderOpener = &updatableAEAD{} +var _ ShortHeaderSealer = &updatableAEAD{} + func newUpdatableAEAD() *updatableAEAD { return &updatableAEAD{} } func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { - a.ShortHeaderOpener = newShortHeaderOpener(createAEAD(suite, trafficSecret)) + aead, hpDecrypter := createAEAD(suite, trafficSecret) + a.shortHeaderOpener = newShortHeaderOpener(aead) + if len(a.hpMask) == 0 { + a.hpMask = make([]byte, hpDecrypter.BlockSize()) + } else if len(a.hpMask) != hpDecrypter.BlockSize() { + panic("invalid header protection block size") + } + a.hpDecrypter = hpDecrypter } func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { - a.ShortHeaderSealer = newShortHeaderSealer(createAEAD(suite, trafficSecret)) + aead, hpEncrypter := createAEAD(suite, trafficSecret) + a.shortHeaderSealer = newShortHeaderSealer(aead) + if len(a.hpMask) == 0 { + a.hpMask = make([]byte, hpEncrypter.BlockSize()) + } else if len(a.hpMask) != hpEncrypter.BlockSize() { + panic("invalid header protection block size") + } + a.hpEncrypter = hpEncrypter +} + +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != a.hpEncrypter.BlockSize() { + panic("invalid sample size") + } + a.hpEncrypter.Encrypt(a.hpMask, sample) + *firstByte ^= a.hpMask[0] & 0x1f + for i := range pnBytes { + pnBytes[i] ^= a.hpMask[i+1] + } +} + +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != a.hpDecrypter.BlockSize() { + panic("invalid sample size") + } + a.hpDecrypter.Encrypt(a.hpMask, sample) + *firstByte ^= a.hpMask[0] & 0x1f + for i := range pnBytes { + pnBytes[i] ^= a.hpMask[i+1] + } }