forked from quic-go/quic-go
refactor initialization of AEADs
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type sealer struct {
|
||||
@@ -86,3 +89,14 @@ func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes
|
||||
pnBytes[i] ^= o.hpMask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
func createAEAD(suite cipherSuite, trafficSecret []byte) (cipher.AEAD, cipher.Block) {
|
||||
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
|
||||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hpDecrypter, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return suite.AEAD(key, iv), hpDecrypter
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -483,23 +482,15 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
|
||||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hpDecrypter, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
switch h.readEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = newLongHeaderOpener(suite.AEAD(key, iv), hpDecrypter)
|
||||
h.handshakeOpener = newLongHeaderOpener(createAEAD(suite, trafficSecret))
|
||||
h.logger.Debugf("Installed Handshake Read keys")
|
||||
case protocol.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetReadKey(suite.AEAD(key, iv), hpDecrypter)
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
h.logger.Debugf("Installed 1-RTT Read keys")
|
||||
default:
|
||||
@@ -510,23 +501,15 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
|
||||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hpEncrypter, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeSealer = newLongHeaderSealer(suite.AEAD(key, iv), hpEncrypter)
|
||||
h.handshakeSealer = newLongHeaderSealer(createAEAD(suite, trafficSecret))
|
||||
h.logger.Debugf("Installed Handshake Write keys")
|
||||
case protocol.EncryptionHandshake:
|
||||
h.writeEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetWriteKey(suite.AEAD(key, iv), hpEncrypter)
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
h.logger.Debugf("Installed 1-RTT Write keys")
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
@@ -9,6 +11,13 @@ import (
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type cipherSuite interface {
|
||||
Hash() crypto.Hash
|
||||
KeyLen() int
|
||||
IVLen() int
|
||||
AEAD(key, nonce []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
@@ -84,10 +84,10 @@ func newUpdatableAEAD() *updatableAEAD {
|
||||
return &updatableAEAD{}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetReadKey(aead cipher.AEAD, hpDecrypter cipher.Block) {
|
||||
a.ShortHeaderOpener = newShortHeaderOpener(aead, hpDecrypter)
|
||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.ShortHeaderOpener = newShortHeaderOpener(createAEAD(suite, trafficSecret))
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetWriteKey(aead cipher.AEAD, hpDecrypter cipher.Block) {
|
||||
a.ShortHeaderSealer = newShortHeaderSealer(aead, hpDecrypter)
|
||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.ShortHeaderSealer = newShortHeaderSealer(createAEAD(suite, trafficSecret))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
@@ -10,25 +11,29 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockCipherSuite struct{}
|
||||
|
||||
var _ cipherSuite = &mockCipherSuite{}
|
||||
|
||||
func (c *mockCipherSuite) Hash() crypto.Hash { return crypto.SHA256 }
|
||||
func (c *mockCipherSuite) KeyLen() int { return 16 }
|
||||
func (c *mockCipherSuite) IVLen() int { return 12 }
|
||||
func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD {
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return gcm
|
||||
}
|
||||
|
||||
var _ = Describe("Updatable AEAD", func() {
|
||||
getAEAD := func() *updatableAEAD {
|
||||
key := make([]byte, 16)
|
||||
hpKey := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
rand.Read(hpKey)
|
||||
block, err := aes.NewCipher(key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aead, err := cipher.NewGCM(block)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hpBlock, err := aes.NewCipher(hpKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
iv := make([]byte, 12)
|
||||
rand.Read(iv)
|
||||
trafficSecret := make([]byte, 16)
|
||||
rand.Read(trafficSecret)
|
||||
|
||||
u := newUpdatableAEAD()
|
||||
u.SetReadKey(aead, hpBlock)
|
||||
u.SetWriteKey(aead, hpBlock)
|
||||
u.SetReadKey(&mockCipherSuite{}, trafficSecret)
|
||||
u.SetWriteKey(&mockCipherSuite{}, trafficSecret)
|
||||
return u
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user