perform a key update when receiving a packet with a different key phase

This commit is contained in:
Marten Seemann
2019-06-11 20:25:51 +08:00
parent e74ede678f
commit 1fb970cbac
2 changed files with 182 additions and 70 deletions

View File

@@ -1,15 +1,30 @@
package handshake
import (
"crypto"
"crypto/cipher"
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
type updatableAEAD struct {
openAEAD cipher.AEAD
sealAEAD cipher.AEAD
suite cipherSuite
keyPhase protocol.KeyPhase
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
rcvAEAD cipher.AEAD
sendAEAD cipher.AEAD
nextRcvAEAD cipher.AEAD
nextSendAEAD cipher.AEAD
nextRcvTrafficSecret []byte
nextSendTrafficSecret []byte
hpDecrypter cipher.Block
hpEncrypter cipher.Block
@@ -23,63 +38,100 @@ var _ ShortHeaderOpener = &updatableAEAD{}
var _ ShortHeaderSealer = &updatableAEAD{}
func newUpdatableAEAD() *updatableAEAD {
return &updatableAEAD{}
return &updatableAEAD{
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
}
}
func (a *updatableAEAD) rollKeys() {
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.keyPhase = a.keyPhase.Next()
a.prevRcvAEAD = a.rcvAEAD
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextRcvTrafficSecret)
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash(), a.nextSendTrafficSecret)
a.nextRcvAEAD, _ = createAEAD(a.suite, a.nextRcvTrafficSecret)
a.nextSendAEAD, _ = createAEAD(a.suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
return qtls.HkdfExpandLabel(hash, ts, []byte{}, "traffic upd", hash.Size())
}
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
aead, hpDecrypter := createAEAD(suite, trafficSecret)
if len(a.nonceBuf) == 0 {
a.nonceBuf = make([]byte, aead.NonceSize())
} else if len(a.nonceBuf) != aead.NonceSize() {
panic("invalid nonce size")
a.rcvAEAD, a.hpDecrypter = createAEAD(suite, trafficSecret)
if a.suite == nil {
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
a.suite = suite
}
a.openAEAD = aead
if len(a.hpMask) == 0 {
a.hpMask = make([]byte, hpDecrypter.BlockSize())
} else if len(a.hpMask) != hpDecrypter.BlockSize() {
panic("invalid header protection block size")
}
a.hpDecrypter = hpDecrypter
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
a.nextRcvAEAD, _ = createAEAD(suite, a.nextRcvTrafficSecret)
}
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
aead, hpEncrypter := createAEAD(suite, trafficSecret)
if len(a.nonceBuf) == 0 {
a.nonceBuf = make([]byte, aead.NonceSize())
} else if len(a.nonceBuf) != aead.NonceSize() {
panic("invalid nonce size")
a.sendAEAD, a.hpEncrypter = createAEAD(suite, trafficSecret)
if a.suite == nil {
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
a.suite = suite
}
a.sealAEAD = aead
if len(a.hpMask) == 0 {
a.hpMask = make([]byte, hpEncrypter.BlockSize())
} else if len(a.hpMask) != hpEncrypter.BlockSize() {
panic("invalid header protection block size")
}
a.hpEncrypter = hpEncrypter
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash(), trafficSecret)
a.nextSendAEAD, _ = createAEAD(suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, _ protocol.KeyPhase, ad []byte) ([]byte, error) {
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase {
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
// TODO: check that prevRcv actually exists
// we updated the key, but the peer hasn't updated yet
return a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
}
// try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err == nil {
// if opening succeeds, roll over to the next key phase
a.rollKeys()
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return a.openAEAD.Open(dst, a.nonceBuf, src, ad)
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err == nil && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
a.firstSentWithCurrentKey = pn
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return a.sealAEAD.Seal(dst, a.nonceBuf, src, ad)
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}
func (a *updatableAEAD) KeyPhase() protocol.KeyPhase {
return protocol.KeyPhaseOne
return a.keyPhase
}
func (a *updatableAEAD) Overhead() int {
return a.sealAEAD.Overhead()
return a.sendAEAD.Overhead()
}
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {