diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index c5702d35..6ba2d561 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -14,19 +14,14 @@ type sealer struct { // use a single slice to avoid allocations nonceBuf []byte hpMask []byte - - // short headers protect 5 bits in the first byte, long headers only 4 - is1RTT bool } var _ LongHeaderSealer = &sealer{} -var _ ShortHeaderSealer = &sealer{} -func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) ShortHeaderSealer { +func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer { return &sealer{ aead: aead, nonceBuf: make([]byte, aead.NonceSize()), - is1RTT: is1RTT, hpEncrypter: hpEncrypter, hpMask: make([]byte, hpEncrypter.BlockSize()), } @@ -44,11 +39,7 @@ func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { panic("invalid sample size") } s.hpEncrypter.Encrypt(s.hpMask, sample) - if s.is1RTT { - *firstByte ^= s.hpMask[0] & 0x1f - } else { - *firstByte ^= s.hpMask[0] & 0xf - } + *firstByte ^= s.hpMask[0] & 0xf for i := range pnBytes { pnBytes[i] ^= s.hpMask[i+1] } @@ -58,10 +49,6 @@ func (s *sealer) Overhead() int { return s.aead.Overhead() } -func (s *sealer) KeyPhase() protocol.KeyPhase { - return protocol.KeyPhaseZero -} - type longHeaderOpener struct { aead cipher.AEAD pnDecrypter cipher.Block @@ -99,41 +86,3 @@ func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes pnBytes[i] ^= o.hpMask[i+1] } } - -type shortHeaderOpener struct { - aead cipher.AEAD - pnDecrypter cipher.Block - - // use a single slice to avoid allocations - nonceBuf []byte - hpMask []byte -} - -var _ ShortHeaderOpener = &shortHeaderOpener{} - -func newShortHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) ShortHeaderOpener { - return &shortHeaderOpener{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - pnDecrypter: pnDecrypter, - hpMask: make([]byte, pnDecrypter.BlockSize()), - } -} - -func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return o.aead.Open(dst, o.nonceBuf, src, ad) -} - -func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != o.pnDecrypter.BlockSize() { - panic("invalid sample size") - } - o.pnDecrypter.Encrypt(o.hpMask, sample) - *firstByte ^= o.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= o.hpMask[i+1] - } -} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index e1f105cc..d211d3a1 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -5,13 +5,12 @@ import ( "crypto/cipher" "crypto/rand" - "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("AEAD", func() { - getLongHeaderSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { key := make([]byte, 16) hpKey := make([]byte, 16) rand.Read(key) @@ -25,55 +24,30 @@ var _ = Describe("AEAD", func() { iv := make([]byte, 12) rand.Read(iv) - return newSealer(aead, hpBlock, false), newLongHeaderOpener(aead, hpBlock) - } - - getShortHeaderSealerAndOpener := func() (ShortHeaderSealer, ShortHeaderOpener) { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hpBlock, err := aes.NewCipher(hpKey) - Expect(err).ToNot(HaveOccurred()) - - iv := make([]byte, 12) - rand.Read(iv) - return newSealer(aead, hpBlock, true), newShortHeaderOpener(aead, hpBlock) + return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock) } Context("message encryption", func() { msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad := []byte("Donec in velit neque.") - It("encrypts and decrypts a message, for long headers", func() { - sealer, opener := getLongHeaderSealerAndOpener() + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) opened, err := opener.Open(nil, encrypted, 0x1337, ad) Expect(err).ToNot(HaveOccurred()) Expect(opened).To(Equal(msg)) }) - It("encrypts and decrypts a message, for short headers", func() { - sealer, opener := getShortHeaderSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getLongHeaderSealerAndOpener() + sealer, opener := getSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) Expect(err).To(MatchError("cipher: message authentication failed")) }) It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getLongHeaderSealerAndOpener() + sealer, opener := getSealerAndOpener() encrypted := sealer.Seal(nil, msg, 0x1337, ad) _, err := opener.Open(nil, encrypted, 0x42, ad) Expect(err).To(MatchError("cipher: message authentication failed")) @@ -81,8 +55,8 @@ var _ = Describe("AEAD", func() { }) Context("header encryption", func() { - It("encrypts and encrypts the header, for long headers", func() { - sealer, opener := getLongHeaderSealerAndOpener() + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() var lastFourBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -101,28 +75,8 @@ var _ = Describe("AEAD", func() { Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) }) - It("encrypts and encrypts the header, for short headers", func() { - sealer, opener := getShortHeaderSealerAndOpener() - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getLongHeaderSealerAndOpener() + sealer, opener := getSealerAndOpener() header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} sample := make([]byte, 16) rand.Read(sample) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 3606b810..48c9bc4f 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -112,9 +112,10 @@ type cryptoSetup struct { handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer - oneRTTStream io.Writer - opener ShortHeaderOpener - sealer ShortHeaderSealer + oneRTTStream io.Writer + aead *updatableAEAD + has1RTTSealer bool + has1RTTOpener bool } var _ qtls.RecordLayer = &cryptoSetup{} @@ -202,6 +203,7 @@ func newCryptoSetup( initialOpener: initialOpener, handshakeStream: handshakeStream, oneRTTStream: oneRTTStream, + aead: newUpdatableAEAD(), readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, @@ -497,7 +499,8 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) h.logger.Debugf("Installed Handshake Read keys") case protocol.EncryptionHandshake: h.readEncLevel = protocol.Encryption1RTT - h.opener = newShortHeaderOpener(suite.AEAD(key, iv), hpDecrypter) + h.aead.SetReadKey(suite.AEAD(key, iv), hpDecrypter) + h.has1RTTOpener = true h.logger.Debugf("Installed 1-RTT Read keys") default: panic("unexpected read encryption level") @@ -519,11 +522,12 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) switch h.writeEncLevel { case protocol.EncryptionInitial: h.writeEncLevel = protocol.EncryptionHandshake - h.handshakeSealer = newSealer(suite.AEAD(key, iv), hpEncrypter, false) + h.handshakeSealer = newLongHeaderSealer(suite.AEAD(key, iv), hpEncrypter) h.logger.Debugf("Installed Handshake Write keys") case protocol.EncryptionHandshake: h.writeEncLevel = protocol.Encryption1RTT - h.sealer = newSealer(suite.AEAD(key, iv), hpEncrypter, true) + h.aead.SetWriteKey(suite.AEAD(key, iv), hpEncrypter) + h.has1RTTSealer = true h.logger.Debugf("Installed 1-RTT Write keys") default: panic("unexpected write encryption level") @@ -585,10 +589,10 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() - if h.sealer == nil { + if !h.has1RTTSealer { return nil, errors.New("CryptoSetup: no sealer with encryption level 1-RTT") } - return h.sealer, nil + return h.aead, nil } func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { @@ -619,10 +623,10 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { h.mutex.Lock() defer h.mutex.Unlock() - if h.opener == nil { + if !h.has1RTTOpener { return nil, ErrOpenerNotYetAvailable } - return h.opener, nil + return h.aead, nil } func (h *cryptoSetup) ConnectionState() tls.ConnectionState { diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index fcb08e15..28ea2a95 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -34,7 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo if err != nil { return nil, nil, err } - return newSealer(encrypter, hpEncrypter, false), newLongHeaderOpener(decrypter, hpDecrypter), nil + return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil } func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go new file mode 100644 index 00000000..591f6677 --- /dev/null +++ b/internal/handshake/updatable_aead.go @@ -0,0 +1,93 @@ +package handshake + +import ( + "crypto/cipher" + "encoding/binary" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type shortHeaderOpener struct { + aead cipher.AEAD + hpDecrypter cipher.Block + + // use a single slice to avoid allocations + nonceBuf []byte + hpMask []byte +} + +var _ ShortHeaderOpener = &shortHeaderOpener{} + +func newShortHeaderOpener(aead cipher.AEAD, hpDecrypter cipher.Block) ShortHeaderOpener { + return &shortHeaderOpener{ + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + hpDecrypter: hpDecrypter, + hpMask: make([]byte, hpDecrypter.BlockSize()), + } +} + +func (o *shortHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return o.aead.Open(dst, o.nonceBuf, src, ad) +} + +func (o *shortHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != o.hpDecrypter.BlockSize() { + panic("invalid sample size") + } + o.hpDecrypter.Encrypt(o.hpMask, sample) + *firstByte ^= o.hpMask[0] & 0x1f + for i := range pnBytes { + pnBytes[i] ^= o.hpMask[i+1] + } +} + +type shortHeaderSealer struct { + sealer +} + +func newShortHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) ShortHeaderSealer { + return &shortHeaderSealer{ + sealer: sealer{ + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + hpEncrypter: hpEncrypter, + hpMask: make([]byte, hpEncrypter.BlockSize()), + }, + } +} + +func (s *shortHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != s.hpEncrypter.BlockSize() { + panic("invalid sample size") + } + s.hpEncrypter.Encrypt(s.hpMask, sample) + *firstByte ^= s.hpMask[0] & 0x1f + for i := range pnBytes { + pnBytes[i] ^= s.hpMask[i+1] + } +} + +func (s *shortHeaderSealer) KeyPhase() protocol.KeyPhase { + return protocol.KeyPhaseOne +} + +type updatableAEAD struct { + ShortHeaderOpener + ShortHeaderSealer +} + +func newUpdatableAEAD() *updatableAEAD { + return &updatableAEAD{} +} + +func (a *updatableAEAD) SetReadKey(aead cipher.AEAD, hpDecrypter cipher.Block) { + a.ShortHeaderOpener = newShortHeaderOpener(aead, hpDecrypter) +} + +func (a *updatableAEAD) SetWriteKey(aead cipher.AEAD, hpDecrypter cipher.Block) { + a.ShortHeaderSealer = newShortHeaderSealer(aead, hpDecrypter) +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go new file mode 100644 index 00000000..19edd02a --- /dev/null +++ b/internal/handshake/updatable_aead_test.go @@ -0,0 +1,83 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Updatable AEAD", func() { + getAEAD := func() *updatableAEAD { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + hpBlock, err := aes.NewCipher(hpKey) + Expect(err).ToNot(HaveOccurred()) + + iv := make([]byte, 12) + rand.Read(iv) + + u := newUpdatableAEAD() + u.SetReadKey(aead, hpBlock) + u.SetWriteKey(aead, hpBlock) + return u + } + + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + aead := getAEAD() + encrypted := aead.Seal(nil, msg, 0x1337, ad) + opened, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + aead := getAEAD() + encrypted := aead.Seal(nil, msg, 0x1337, ad) + _, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + + It("fails to open a message if the packet number is not the same", func() { + aead := getAEAD() + encrypted := aead.Seal(nil, msg, 0x1337, ad) + _, err := aead.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + }) + + Context("header encryption", func() { + It("encrypts and decrypts the header", func() { + aead := getAEAD() + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + aead.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + aead.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) + }) + }) +})