forked from quic-go/quic-go
perform a key update when receiving a packet with a different key phase
This commit is contained in:
@@ -1,15 +1,30 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type updatableAEAD struct {
|
||||
openAEAD cipher.AEAD
|
||||
sealAEAD cipher.AEAD
|
||||
suite cipherSuite
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
|
||||
prevRcvAEAD cipher.AEAD
|
||||
|
||||
firstRcvdWithCurrentKey protocol.PacketNumber
|
||||
firstSentWithCurrentKey protocol.PacketNumber
|
||||
rcvAEAD cipher.AEAD
|
||||
sendAEAD cipher.AEAD
|
||||
|
||||
nextRcvAEAD cipher.AEAD
|
||||
nextSendAEAD cipher.AEAD
|
||||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
hpDecrypter cipher.Block
|
||||
hpEncrypter cipher.Block
|
||||
@@ -23,63 +38,100 @@ var _ ShortHeaderOpener = &updatableAEAD{}
|
||||
var _ ShortHeaderSealer = &updatableAEAD{}
|
||||
|
||||
func newUpdatableAEAD() *updatableAEAD {
|
||||
return &updatableAEAD{}
|
||||
return &updatableAEAD{
|
||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) rollKeys() {
|
||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.keyPhase = a.keyPhase.Next()
|
||||
a.prevRcvAEAD = a.rcvAEAD
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextSendTrafficSecret)
|
||||
a.nextRcvAEAD, _ = createAEAD(a.suite, a.nextRcvTrafficSecret)
|
||||
a.nextSendAEAD, _ = createAEAD(a.suite, a.nextSendTrafficSecret)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
|
||||
return qtls.HkdfExpandLabel(hash, ts, []byte{}, "traffic upd", hash.Size())
|
||||
}
|
||||
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||
aead, hpDecrypter := createAEAD(suite, trafficSecret)
|
||||
if len(a.nonceBuf) == 0 {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
} else if len(a.nonceBuf) != aead.NonceSize() {
|
||||
panic("invalid nonce size")
|
||||
a.rcvAEAD, a.hpDecrypter = createAEAD(suite, trafficSecret)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
|
||||
a.suite = suite
|
||||
}
|
||||
a.openAEAD = aead
|
||||
|
||||
if len(a.hpMask) == 0 {
|
||||
a.hpMask = make([]byte, hpDecrypter.BlockSize())
|
||||
} else if len(a.hpMask) != hpDecrypter.BlockSize() {
|
||||
panic("invalid header protection block size")
|
||||
}
|
||||
a.hpDecrypter = hpDecrypter
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
|
||||
a.nextRcvAEAD, _ = createAEAD(suite, a.nextRcvTrafficSecret)
|
||||
}
|
||||
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||
aead, hpEncrypter := createAEAD(suite, trafficSecret)
|
||||
if len(a.nonceBuf) == 0 {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
} else if len(a.nonceBuf) != aead.NonceSize() {
|
||||
panic("invalid nonce size")
|
||||
a.sendAEAD, a.hpEncrypter = createAEAD(suite, trafficSecret)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
|
||||
a.suite = suite
|
||||
}
|
||||
a.sealAEAD = aead
|
||||
|
||||
if len(a.hpMask) == 0 {
|
||||
a.hpMask = make([]byte, hpEncrypter.BlockSize())
|
||||
} else if len(a.hpMask) != hpEncrypter.BlockSize() {
|
||||
panic("invalid header protection block size")
|
||||
}
|
||||
a.hpEncrypter = hpEncrypter
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
|
||||
a.nextSendAEAD, _ = createAEAD(suite, a.nextSendTrafficSecret)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) {
|
||||
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
if kp != a.keyPhase {
|
||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||
// TODO: check that prevRcv actually exists
|
||||
// we updated the key, but the peer hasn't updated yet
|
||||
return a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
}
|
||||
// try opening the packet with the next key phase
|
||||
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err == nil {
|
||||
// if opening succeeds, roll over to the next key phase
|
||||
a.rollKeys()
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return a.openAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err == nil && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
a.firstSentWithCurrentKey = pn
|
||||
}
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return a.sealAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) KeyPhase() protocol.KeyPhase {
|
||||
return protocol.KeyPhaseOne
|
||||
return a.keyPhase
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Overhead() int {
|
||||
return a.sealAEAD.Overhead()
|
||||
return a.sendAEAD.Overhead()
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
|
||||
@@ -27,62 +27,122 @@ func (c *mockCipherSuite) AEAD(key, _ []byte) cipher.AEAD {
|
||||
}
|
||||
|
||||
var _ = Describe("Updatable AEAD", func() {
|
||||
getAEAD := func() *updatableAEAD {
|
||||
trafficSecret := make([]byte, 16)
|
||||
rand.Read(trafficSecret)
|
||||
getPeers := func() (client, server *updatableAEAD) {
|
||||
trafficSecret1 := make([]byte, 16)
|
||||
trafficSecret2 := make([]byte, 16)
|
||||
rand.Read(trafficSecret1)
|
||||
rand.Read(trafficSecret2)
|
||||
|
||||
u := newUpdatableAEAD()
|
||||
u.SetReadKey(&mockCipherSuite{}, trafficSecret)
|
||||
u.SetWriteKey(&mockCipherSuite{}, trafficSecret)
|
||||
return u
|
||||
client = newUpdatableAEAD()
|
||||
server = newUpdatableAEAD()
|
||||
client.SetReadKey(&mockCipherSuite{}, trafficSecret2)
|
||||
client.SetWriteKey(&mockCipherSuite{}, trafficSecret1)
|
||||
server.SetReadKey(&mockCipherSuite{}, trafficSecret1)
|
||||
server.SetWriteKey(&mockCipherSuite{}, trafficSecret2)
|
||||
return
|
||||
}
|
||||
|
||||
Context("message encryption", func() {
|
||||
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad := []byte("Donec in velit neque.")
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
aead := getAEAD()
|
||||
encrypted := aead.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
aead := getAEAD()
|
||||
encrypted := aead.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := aead.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
aead := getAEAD()
|
||||
encrypted := aead.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := aead.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("header encryption", func() {
|
||||
Context("header protection", func() {
|
||||
It("encrypts and decrypts the header", func() {
|
||||
aead := getAEAD()
|
||||
server, client := getPeers()
|
||||
var lastFiveBitsDifferent int
|
||||
for i := 0; i < 100; i++ {
|
||||
sample := make([]byte, 16)
|
||||
rand.Read(sample)
|
||||
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
|
||||
aead.EncryptHeader(sample, &header[0], header[9:13])
|
||||
client.EncryptHeader(sample, &header[0], header[9:13])
|
||||
if header[0]&0x1f != 0xb5&0x1f {
|
||||
lastFiveBitsDifferent++
|
||||
}
|
||||
Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
|
||||
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||
aead.DecryptHeader(sample, &header[0], header[9:13])
|
||||
server.DecryptHeader(sample, &header[0], header[9:13])
|
||||
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
|
||||
}
|
||||
Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
|
||||
})
|
||||
})
|
||||
|
||||
Context("message encryption", func() {
|
||||
var msg, ad []byte
|
||||
var server, client *updatableAEAD
|
||||
|
||||
BeforeEach(func() {
|
||||
server, client = getPeers()
|
||||
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
ad = []byte("Donec in velit neque.")
|
||||
})
|
||||
|
||||
It("encrypts and decrypts a message", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
opened, err := client.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(opened).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("fails to open a message if the associated data is not the same", func() {
|
||||
encrypted := client.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
})
|
||||
|
||||
It("fails to open a message if the packet number is not the same", func() {
|
||||
encrypted := server.Seal(nil, msg, 0x1337, ad)
|
||||
_, err := client.Open(nil, encrypted, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
})
|
||||
|
||||
Context("key updates", func() {
|
||||
It("updates keys", func() {
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
|
||||
server.rollKeys()
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
|
||||
Expect(encrypted0).ToNot(Equal(encrypted1))
|
||||
// expect opening to fail. The client didn't roll keys yet
|
||||
_, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
||||
client.rollKeys()
|
||||
decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
})
|
||||
|
||||
It("updates the keys when receiving a packet with the next key phase", func() {
|
||||
encrypted0 := client.Seal(nil, msg, 0x42, ad)
|
||||
decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
client.rollKeys()
|
||||
encrypted1 := client.Seal(nil, msg, 0x43, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
|
||||
It("opens a reordered packet with the old keys after an update", func() {
|
||||
encrypted01 := client.Seal(nil, msg, 0x42, ad)
|
||||
encrypted02 := client.Seal(nil, msg, 0x43, ad)
|
||||
// receive the first packet with key phase 0
|
||||
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// now receive a packet with key phase 1
|
||||
client.rollKeys()
|
||||
encrypted1 := client.Seal(nil, msg, 0x44, ad)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// now receive a reordered packet with key phase 0
|
||||
decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decrypted).To(Equal(msg))
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user