diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 6ba2d561..673544e2 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -1,10 +1,13 @@ package handshake import ( + "crypto/aes" "crypto/cipher" "encoding/binary" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/marten-seemann/qtls" ) type sealer struct { @@ -86,3 +89,14 @@ func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes pnBytes[i] ^= o.hpMask[i+1] } } + +func createAEAD(suite cipherSuite, trafficSecret []byte) (cipher.AEAD, cipher.Block) { + key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen()) + iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) + hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) + hpDecrypter, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return suite.AEAD(key, iv), hpDecrypter +} diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 48c9bc4f..086e46b6 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -1,7 +1,6 @@ package handshake import ( - "crypto/aes" "crypto/tls" "errors" "fmt" @@ -483,23 +482,15 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { } func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { - key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen()) - iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) - hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) - hpDecrypter, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - h.mutex.Lock() switch h.readEncLevel { case protocol.EncryptionInitial: h.readEncLevel = protocol.EncryptionHandshake - h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter) + h.handshakeOpener = newLongHeaderOpener(createAEAD(suite, trafficSecret)) h.logger.Debugf("Installed Handshake Read keys") case protocol.EncryptionHandshake: h.readEncLevel = protocol.Encryption1RTT - h.aead.SetReadKey(suite.AEAD(key, iv), hpDecrypter) + h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true h.logger.Debugf("Installed 1-RTT Read keys") default: @@ -510,23 +501,15 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) } func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) { - key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen()) - iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) - hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) - hpEncrypter, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - h.mutex.Lock() switch h.writeEncLevel { case protocol.EncryptionInitial: h.writeEncLevel = protocol.EncryptionHandshake - h.handshakeSealer = newLongHeaderSealer(suite.AEAD(key, iv), hpEncrypter) + h.handshakeSealer = newLongHeaderSealer(createAEAD(suite, trafficSecret)) h.logger.Debugf("Installed Handshake Write keys") case protocol.EncryptionHandshake: h.writeEncLevel = protocol.Encryption1RTT - h.aead.SetWriteKey(suite.AEAD(key, iv), hpEncrypter) + h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true h.logger.Debugf("Installed 1-RTT Write keys") default: diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index cd093142..52d0723f 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -1,6 +1,8 @@ package handshake import ( + "crypto" + "crypto/cipher" "crypto/tls" "net" "time" @@ -9,6 +11,13 @@ import ( "github.com/marten-seemann/qtls" ) +type cipherSuite interface { + Hash() crypto.Hash + KeyLen() int + IVLen() int + AEAD(key, nonce []byte) cipher.AEAD +} + type conn struct { remoteAddr net.Addr } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 591f6677..1943da0c 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -84,10 +84,10 @@ func newUpdatableAEAD() *updatableAEAD { return &updatableAEAD{} } -func (a *updatableAEAD) SetReadKey(aead cipher.AEAD, hpDecrypter cipher.Block) { - a.ShortHeaderOpener = newShortHeaderOpener(aead, hpDecrypter) +func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { + a.ShortHeaderOpener = newShortHeaderOpener(createAEAD(suite, trafficSecret)) } -func (a *updatableAEAD) SetWriteKey(aead cipher.AEAD, hpDecrypter cipher.Block) { - a.ShortHeaderSealer = newShortHeaderSealer(aead, hpDecrypter) +func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { + a.ShortHeaderSealer = newShortHeaderSealer(createAEAD(suite, trafficSecret)) } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 19edd02a..05626de1 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -1,6 +1,7 @@ package handshake import ( + "crypto" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -10,25 +11,29 @@ import ( . "github.com/onsi/gomega" ) +type mockCipherSuite struct{} + +var _ cipherSuite = &mockCipherSuite{} + +func (c *mockCipherSuite) Hash() crypto.Hash { return crypto.SHA256 } +func (c *mockCipherSuite) KeyLen() int { return 16 } +func (c *mockCipherSuite) IVLen() int { return 12 } +func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD { + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + gcm, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + return gcm +} + var _ = Describe("Updatable AEAD", func() { getAEAD := func() *updatableAEAD { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hpBlock, err := aes.NewCipher(hpKey) - Expect(err).ToNot(HaveOccurred()) - - iv := make([]byte, 12) - rand.Read(iv) + trafficSecret := make([]byte, 16) + rand.Read(trafficSecret) u := newUpdatableAEAD() - u.SetReadKey(aead, hpBlock) - u.SetWriteKey(aead, hpBlock) + u.SetReadKey(&mockCipherSuite{}, trafficSecret) + u.SetWriteKey(&mockCipherSuite{}, trafficSecret) return u }