From 5a68ba0a024f279895185bf86231f869262e27c5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 14 Dec 2018 16:42:45 +0630 Subject: [PATCH] implement header encrytion and decryption for sealers and openers --- internal/handshake/aead.go | 70 ++++++++++++--- internal/handshake/aead_test.go | 112 +++++++++++++++++++----- internal/handshake/crypto_setup.go | 25 +++++- internal/handshake/initial_aead.go | 14 ++- internal/handshake/initial_aead_test.go | 40 ++++++++- internal/handshake/interface.go | 2 + internal/mocks/opener.go | 10 +++ internal/mocks/sealer.go | 10 +++ 8 files changed, 239 insertions(+), 44 deletions(-) diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 21d61a8f..21a5b921 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -8,20 +8,28 @@ import ( ) type sealer struct { - iv []byte - aead cipher.AEAD + iv []byte + aead cipher.AEAD + pnEncrypter cipher.Block // use a single slice to avoid allocations nonceBuf []byte + pnMask []byte + + // short headers protect 5 bits in the first byte, long headers only 4 + is1RTT bool } var _ Sealer = &sealer{} -func newSealer(aead cipher.AEAD, iv []byte) Sealer { +func newSealer(aead cipher.AEAD, iv []byte, pnEncrypter cipher.Block, is1RTT bool) Sealer { return &sealer{ - iv: iv, - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), + iv: iv, + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + is1RTT: is1RTT, + pnEncrypter: pnEncrypter, + pnMask: make([]byte, pnEncrypter.BlockSize()), } } @@ -30,25 +38,48 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by return s.aead.Seal(dst, s.nonceBuf, src, ad) } +func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != s.pnEncrypter.BlockSize() { + panic("invalid sample size") + } + s.pnEncrypter.Encrypt(s.pnMask, sample) + if s.is1RTT { + *firstByte ^= s.pnMask[0] & 0x1f + } else { + *firstByte ^= s.pnMask[0] & 0xf + } + for i := range pnBytes { + pnBytes[i] ^= s.pnMask[i+1] + } +} + func (s *sealer) Overhead() int { return s.aead.Overhead() } type opener struct { - iv []byte - aead cipher.AEAD + iv []byte + aead cipher.AEAD + pnDecrypter cipher.Block // use a single slice to avoid allocations nonceBuf []byte + pnMask []byte + + // short headers protect 5 bits in the first byte, long headers only 4 + is1RTT bool } var _ Opener = &opener{} -func newOpener(aead cipher.AEAD, iv []byte) Opener { +func newOpener(aead cipher.AEAD, iv []byte, pnDecrypter cipher.Block, is1RTT bool) Opener { return &opener{ - iv: iv, - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), + iv: iv, + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + is1RTT: is1RTT, + pnDecrypter: pnDecrypter, + pnMask: make([]byte, pnDecrypter.BlockSize()), } } @@ -56,3 +87,18 @@ func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]b binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) return o.aead.Open(dst, o.nonceBuf, src, ad) } + +func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + if len(sample) != o.pnDecrypter.BlockSize() { + panic("invalid sample size") + } + o.pnDecrypter.Encrypt(o.pnMask, sample) + if o.is1RTT { + *firstByte ^= o.pnMask[0] & 0x1f + } else { + *firstByte ^= o.pnMask[0] & 0xf + } + for i := range pnBytes { + pnBytes[i] ^= o.pnMask[i+1] + } +} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 3556630c..3aa2580c 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -10,42 +10,106 @@ import ( ) var _ = Describe("AEAD", func() { - var sealer Sealer - var opener Opener - - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") - - BeforeEach(func() { + getSealerAndOpener := func(is1RTT bool) (Sealer, Opener) { key := make([]byte, 16) + pnKey := make([]byte, 16) rand.Read(key) + rand.Read(pnKey) block, err := aes.NewCipher(key) Expect(err).ToNot(HaveOccurred()) aead, err := cipher.NewGCM(block) Expect(err).ToNot(HaveOccurred()) + pnBlock, err := aes.NewCipher(pnKey) + Expect(err).ToNot(HaveOccurred()) iv := make([]byte, 12) rand.Read(iv) - sealer = newSealer(aead, iv) - opener = newOpener(aead, iv) + return newSealer(aead, iv, pnBlock, is1RTT), newOpener(aead, iv, pnBlock, is1RTT) + } + + Context("message encryption", func() { + var ( + sealer Sealer + opener Opener + ) + + BeforeEach(func() { + sealer, opener = getSealerAndOpener(false) + }) + + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + + It("fails to open a message if the packet number is not the same", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) }) - It("encrypts and decrypts a message", func() { - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) + Context("header encryption", func() { + It("encrypts and encrypts the header, for long headers", func() { + sealer, opener := getSealerAndOpener(false) + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) - It("fails to open a message if the associated data is not the same", func() { - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError("cipher: message authentication failed")) - }) + It("encrypts and encrypts the header, for short headers", func() { + sealer, opener := getSealerAndOpener(true) + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) + }) - It("fails to open a message if the packet number is not the same", func() { - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError("cipher: message authentication failed")) + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener(true) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) }) }) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index f36f3e93..421c99ca 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -1,6 +1,7 @@ package handshake import ( + "crypto/aes" "crypto/tls" "errors" "fmt" @@ -407,7 +408,17 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) - opener := newOpener(suite.AEAD(key, iv), iv) + pnKey := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "pn", suite.KeyLen()) + pnDecrypter, err := aes.NewCipher(pnKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + opener := newOpener( + suite.AEAD(key, iv), + iv, + pnDecrypter, + h.readEncLevel == protocol.Encryption1RTT, + ) switch h.readEncLevel { case protocol.EncryptionInitial: @@ -427,7 +438,17 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) { key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) - sealer := newSealer(suite.AEAD(key, iv), iv) + pnKey := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "pn", suite.KeyLen()) + pnEncrypter, err := aes.NewCipher(pnKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + sealer := newSealer( + suite.AEAD(key, iv), + iv, + pnEncrypter, + h.writeEncLevel == protocol.Encryption1RTT, + ) switch h.writeEncLevel { case protocol.EncryptionInitial: diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index fa97156c..222072af 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -21,8 +21,8 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se mySecret = serverSecret otherSecret = clientSecret } - myKey, _, myIV := computeInitialKeyAndIV(mySecret) - otherKey, _, otherIV := computeInitialKeyAndIV(otherSecret) + myKey, myPNKey, myIV := computeInitialKeyAndIV(mySecret) + otherKey, otherPNKey, otherIV := computeInitialKeyAndIV(otherSecret) encrypterCipher, err := aes.NewCipher(myKey) if err != nil { @@ -32,6 +32,10 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se if err != nil { return nil, nil, err } + pnEncrypter, err := aes.NewCipher(myPNKey) + if err != nil { + return nil, nil, err + } decrypterCipher, err := aes.NewCipher(otherKey) if err != nil { return nil, nil, err @@ -40,7 +44,11 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se if err != nil { return nil, nil, err } - return newSealer(encrypter, myIV), newOpener(decrypter, otherIV), nil + pnDecrypter, err := aes.NewCipher(otherPNKey) + if err != nil { + return nil, nil, err + } + return newSealer(encrypter, myIV, pnEncrypter, false), newOpener(decrypter, otherIV, pnDecrypter, false), nil } func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index 1001354f..f05312a1 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -1,6 +1,8 @@ package handshake import ( + "math/rand" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -64,7 +66,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("seals and opens", func() { - connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) + connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} clientSealer, clientOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) serverSealer, serverOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveServer) @@ -81,8 +83,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("doesn't work if initialized with different connection IDs", func() { - c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) - c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) + c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} + c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} clientSealer, _, err := newInitialAEAD(c1, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) _, serverOpener, err := newInitialAEAD(c2, protocol.PerspectiveServer) @@ -92,4 +94,36 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { _, err = serverOpener.Open(nil, clientMessage, 42, []byte("aad")) Expect(err).To(MatchError("cipher: message authentication failed")) }) + + It("encrypts und decrypts the header", func() { + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + clientSealer, clientOpener, err := newInitialAEAD(connID, protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + serverSealer, serverOpener, err := newInitialAEAD(connID, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + + // the first byte and the last 4 bytes should be encrypted + header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + clientSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + serverOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + + serverSealer.EncryptHeader(sample, &header[0], header[6:10]) + // only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified + Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + clientOpener.DecryptHeader(sample, &header[0], header[6:10]) + Expect(header[0]).To(Equal(byte(0x5e))) + Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4})) + Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) }) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 521af784..38d8e4a6 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -11,11 +11,13 @@ import ( // Opener opens a packet type Opener interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } // Sealer seals a packet type Sealer interface { Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) Overhead() int } diff --git a/internal/mocks/opener.go b/internal/mocks/opener.go index 0268a4a4..dd927289 100644 --- a/internal/mocks/opener.go +++ b/internal/mocks/opener.go @@ -34,6 +34,16 @@ func (m *MockOpener) EXPECT() *MockOpenerMockRecorder { return m.recorder } +// DecryptHeader mocks base method +func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2) +} + +// DecryptHeader indicates an expected call of DecryptHeader +func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2) +} + // Open mocks base method func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) diff --git a/internal/mocks/sealer.go b/internal/mocks/sealer.go index 8b2d1f6e..de309e16 100644 --- a/internal/mocks/sealer.go +++ b/internal/mocks/sealer.go @@ -34,6 +34,16 @@ func (m *MockSealer) EXPECT() *MockSealerMockRecorder { return m.recorder } +// EncryptHeader mocks base method +func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) +} + +// EncryptHeader indicates an expected call of EncryptHeader +func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2) +} + // Overhead mocks base method func (m *MockSealer) Overhead() int { ret := m.ctrl.Call(m, "Overhead")