Files
quic-go/internal/handshake/updatable_aead_test.go
2025-11-14 04:04:40 +03:00

732 lines
26 KiB
Go

package handshake
import (
"crypto/rand"
"crypto/tls"
"fmt"
mrand "math/rand/v2"
"testing"
"time"
"git.geeks-team.ru/gr1ffon/quic-go/internal/monotime"
"git.geeks-team.ru/gr1ffon/quic-go/internal/protocol"
"git.geeks-team.ru/gr1ffon/quic-go/internal/qerr"
"git.geeks-team.ru/gr1ffon/quic-go/internal/utils"
"git.geeks-team.ru/gr1ffon/quic-go/qlog"
"git.geeks-team.ru/gr1ffon/quic-go/qlogwriter"
"git.geeks-team.ru/gr1ffon/quic-go/testutils/events"
"github.com/stretchr/testify/require"
)
// Shared fixtures used by the key update tests in this file.
const (
// msg is the plaintext that the tests pass to Seal.
msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
// ad is the associated data that the tests pass to Seal and Open.
ad = "Donec in velit neque."
)
// randomCipherSuite returns a randomly selected entry of cipherSuites.
func randomCipherSuite() *cipherSuite {
	idx := mrand.IntN(len(cipherSuites))
	return cipherSuites[idx]
}
// setupEndpoints builds a client and a server updatableAEAD for QUIC v1 that
// share a randomly chosen cipher suite and two random 16-byte traffic secrets,
// installed so that each side's write key matches the other side's read key.
//
// The server uses serverRTTStats and records its events into the returned
// recorder; the client gets fresh RTT stats and no recorder.
func setupEndpoints(t *testing.T, serverRTTStats *utils.RTTStats) (client, server *updatableAEAD, serverEventRecorder *events.Recorder) {
	suite := randomCipherSuite()
	recorder := new(events.Recorder)

	// clientSecret protects client -> server traffic, serverSecret the reverse direction.
	clientSecret := make([]byte, 16)
	serverSecret := make([]byte, 16)
	rand.Read(clientSecret)
	rand.Read(serverSecret)

	client = newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, protocol.Version1)
	server = newUpdatableAEAD(serverRTTStats, recorder, utils.DefaultLogger, protocol.Version1)

	client.SetReadKey(suite, serverSecret)
	client.SetWriteKey(suite, clientSecret)
	server.SetReadKey(suite, clientSecret)
	server.SetWriteKey(suite, serverSecret)

	return client, server, recorder
}
// bothSides expands a qlog.KeyDiscarded or qlog.KeyUpdated template into two
// events: one with KeyType set to the client 1-RTT key, followed by one with
// KeyType set to the server 1-RTT key. Only KeyPhase (and Trigger, for
// KeyUpdated) are carried over from the template.
//
// It panics for any other event type.
func bothSides(ev qlogwriter.Event) []qlogwriter.Event {
	if discarded, ok := ev.(qlog.KeyDiscarded); ok {
		return []qlogwriter.Event{
			qlog.KeyDiscarded{KeyType: qlog.KeyTypeClient1RTT, KeyPhase: discarded.KeyPhase},
			qlog.KeyDiscarded{KeyType: qlog.KeyTypeServer1RTT, KeyPhase: discarded.KeyPhase},
		}
	}
	if updated, ok := ev.(qlog.KeyUpdated); ok {
		return []qlogwriter.Event{
			qlog.KeyUpdated{KeyType: qlog.KeyTypeClient1RTT, KeyPhase: updated.KeyPhase, Trigger: updated.Trigger},
			qlog.KeyUpdated{KeyType: qlog.KeyTypeServer1RTT, KeyPhase: updated.KeyPhase, Trigger: updated.Trigger},
		}
	}
	panic("unexpected event type: " + ev.Name())
}
// TestChaChaTestVector seals and header-protects a fixed short header packet
// with TLS_CHACHA20_POLY1305_SHA256 and compares both the sealed payload and
// the final packet against known byte strings, once for QUIC v1 and once for
// QUIC v2.
//
// NOTE(review): the secret and the expected bytes look like the
// ChaCha20-Poly1305 short header sample from RFC 9001, Appendix A.5 (and the
// corresponding RFC 9369 sample for v2) — confirm.
func TestChaChaTestVector(t *testing.T) {
	testCases := []struct {
		version         protocol.Version
		expectedPayload []byte
		expectedPacket  []byte
	}{
		{
			version:         protocol.Version1,
			expectedPayload: splitHexString(t, "655e5cd55c41f69080575d7999c25a5bfb"),
			expectedPacket:  splitHexString(t, "4cfe4189655e5cd55c41f69080575d7999c25a5bfb"),
		},
		{
			version:         protocol.Version2,
			expectedPayload: splitHexString(t, "0ae7b6b932bc27d786f4bc2bb20f2162ba"),
			expectedPacket:  splitHexString(t, "5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"),
		},
	}

	for _, tc := range testCases {
		t.Run(fmt.Sprintf("QUIC %s", tc.version), func(t *testing.T) {
			secret := splitHexString(t, "9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
			aead := newUpdatableAEAD(utils.NewRTTStats(), nil, nil, tc.version)
			// The expected bytes are specific to ChaCha20-Poly1305, so fail
			// loudly if index 2 of cipherSuites ever stops being that suite.
			chacha := cipherSuites[2]
			require.Equal(t, tls.TLS_CHACHA20_POLY1305_SHA256, chacha.ID)
			aead.SetWriteKey(chacha, secret)

			const pnOffset = 1
			header := splitHexString(t, "4200bff4")
			payloadOffset := len(header)
			plaintext := splitHexString(t, "01")
			payload := aead.Seal(nil, plaintext, 654360564, header)
			require.Equal(t, tc.expectedPayload, payload)

			packet := append(header, payload...)
			// The 16-byte header protection sample starts 4 bytes after the
			// packet number offset.
			aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset])
			require.Equal(t, tc.expectedPacket, packet)
		})
	}
}
// TestUpdatableAEADHeaderProtection round-trips header protection for every
// combination of QUIC version and cipher suite. For 100 random samples it
// checks that the client's EncryptHeader leaves the top three bits of the first
// byte and bytes 1-8 untouched, changes the four packet number bytes, and that
// the server's DecryptHeader restores the original header.
func TestUpdatableAEADHeaderProtection(t *testing.T) {
	versions := []protocol.Version{protocol.Version1, protocol.Version2}
	for _, version := range versions {
		for _, suite := range cipherSuites {
			name := fmt.Sprintf("QUIC %s/%s", version, tls.CipherSuiteName(suite.ID))
			t.Run(name, func(t *testing.T) {
				secretA := make([]byte, 16)
				secretB := make([]byte, 16)
				rand.Read(secretA)
				rand.Read(secretB)

				client := newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, version)
				server := newUpdatableAEAD(utils.NewRTTStats(), nil, utils.DefaultLogger, version)
				client.SetReadKey(suite, secretB)
				client.SetWriteKey(suite, secretA)
				server.SetReadKey(suite, secretA)
				server.SetWriteKey(suite, secretB)

				const firstByte = 0xb5
				changedLowBits := 0
				for range 100 {
					sample := make([]byte, 16)
					rand.Read(sample)
					header := []byte{firstByte, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
					pnBytes := header[9:13]

					client.EncryptHeader(sample, &header[0], pnBytes)
					if header[0]&0x1f != firstByte&0x1f {
						changedLowBits++
					}
					require.Equal(t, byte(firstByte&0xe0), header[0]&0xe0)
					require.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8}, header[1:9])
					require.NotEqual(t, []byte{0xde, 0xad, 0xbe, 0xef}, pnBytes)

					server.DecryptHeader(sample, &header[0], pnBytes)
					require.Equal(t, []byte{firstByte, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}, header)
				}
				// Statistical check: the low five bits of the first byte must
				// have changed in more than 75 of the 100 rounds.
				require.Greater(t, changedLowBits, 75)
			})
		}
	}
}
// TestUpdatableAEADEncryptDecryptMessage seals a message on the server and
// opens it on the client for every QUIC version / cipher suite combination.
// Opening must succeed with matching inputs and return ErrDecryptionFailed when
// either the associated data or the packet number differs.
func TestUpdatableAEADEncryptDecryptMessage(t *testing.T) {
	versions := []protocol.Version{protocol.Version1, protocol.Version2}
	for _, version := range versions {
		for _, suite := range cipherSuites {
			name := fmt.Sprintf("QUIC %s/%s", version, tls.CipherSuiteName(suite.ID))
			t.Run(name, func(t *testing.T) {
				var rttStats utils.RTTStats
				secretA := make([]byte, 16)
				secretB := make([]byte, 16)
				rand.Read(secretA)
				rand.Read(secretB)

				client := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, version)
				server := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, version)
				client.SetReadKey(suite, secretB)
				client.SetWriteKey(suite, secretA)
				server.SetReadKey(suite, secretA)
				server.SetWriteKey(suite, secretB)

				plaintext := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
				aad := []byte("Donec in velit neque.")
				sealed := server.Seal(nil, plaintext, 0x1337, aad)

				// matching packet number and associated data
				opened, err := client.Open(nil, sealed, monotime.Now(), 0x1337, protocol.KeyPhaseZero, aad)
				require.NoError(t, err)
				require.Equal(t, plaintext, opened)

				// wrong associated data
				_, err = client.Open(nil, sealed, monotime.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
				require.Equal(t, ErrDecryptionFailed, err)

				// wrong packet number
				_, err = client.Open(nil, sealed, monotime.Now(), 0x42, protocol.KeyPhaseZero, aad)
				require.Equal(t, ErrDecryptionFailed, err)
			})
		}
	}
}
// TestUpdatableAEADPacketNumbers checks two pieces of packet number
// bookkeeping: the sealing side keeps reporting the first packet number it
// sealed, and the opening side only changes how it decodes truncated packet
// numbers after a packet was opened successfully.
func TestUpdatableAEADPacketNumbers(t *testing.T) {
	client, server, _ := setupEndpoints(t, utils.NewRTTStats())

	const firstPN protocol.PacketNumber = 0x1337
	plaintext := []byte("Lorem ipsum")
	aad := []byte("Donec in velit neque.")

	sealed := server.Seal(nil, plaintext, firstPN, aad)
	require.Equal(t, firstPN, server.FirstPacketNumber())
	// sealing another packet must not overwrite the recorded first packet number
	_ = server.Seal(nil, plaintext, 0x1338, aad)
	require.Equal(t, firstPN, server.FirstPacketNumber())

	// A truncated ciphertext fails to open, and 0x38 still decodes to 0x38.
	truncated := sealed[:len(sealed)-1]
	_, err := client.Open(nil, truncated, monotime.Now(), firstPN, protocol.KeyPhaseZero, aad)
	require.Error(t, err)
	require.Equal(t, protocol.PacketNumber(0x38), client.DecodePacketNumber(0x38, protocol.PacketNumberLen1))

	// After a successful open of packet 0x1337, 0x38 decodes to 0x1338.
	_, err = client.Open(nil, sealed, monotime.Now(), firstPN, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)
	require.Equal(t, protocol.PacketNumber(0x1338), client.DecodePacketNumber(0x38, protocol.PacketNumberLen1))
}
// TestAEADLimitReached lowers the invalid packet limit to 10 and feeds the
// client undecryptable packets: the first nine yield ErrDecryptionFailed, the
// tenth yields a transport error with code AEADLimitReached.
func TestAEADLimitReached(t *testing.T) {
	client, _, _ := setupEndpoints(t, utils.NewRTTStats())
	client.invalidPacketLimit = 10

	garbage := func(pn protocol.PacketNumber) error {
		_, err := client.Open(nil, []byte("foobar"), monotime.Now(), pn, protocol.KeyPhaseZero, []byte("ad"))
		return err
	}

	for pn := protocol.PacketNumber(0); pn < 9; pn++ {
		require.Equal(t, ErrDecryptionFailed, garbage(pn))
	}

	err := garbage(10)
	require.Error(t, err)
	var transportErr *qerr.TransportError
	require.ErrorAs(t, err, &transportErr)
	require.Equal(t, qerr.AEADLimitReached, transportErr.ErrorCode)
}
// TestKeyUpdates rolls the server keys and checks that the same plaintext then
// seals to a different ciphertext, that the client cannot open it with its
// key phase 0 keys, and that it can after rolling its own keys.
func TestKeyUpdates(t *testing.T) {
	client, server, _ := setupEndpoints(t, utils.NewRTTStats())
	now := monotime.Now()
	plaintext, aad := []byte(msg), []byte(ad)

	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	sealedPhase0 := server.Seal(nil, plaintext, 0x1337, aad)

	server.rollKeys()
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	sealedPhase1 := server.Seal(nil, plaintext, 0x1337, aad)
	require.NotEqual(t, sealedPhase0, sealedPhase1)

	// the client still only has the keys for key phase 0
	_, err := client.Open(nil, sealedPhase1, now, 0x1337, protocol.KeyPhaseZero, aad)
	require.Equal(t, ErrDecryptionFailed, err)

	client.rollKeys()
	opened, err := client.Open(nil, sealedPhase1, now, 0x1337, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)
	require.Equal(t, msg, string(opened))
}
// func TestUpdatesKeysWhenReceivingPacketWithNextKeyPhase(t *testing.T) {
// rttStats := utils.RTTStats{}
// mockCtrl := gomock.NewController(t)
// serverTracer := mocklogging.NewMockConnectionTracer(mockCtrl)
// trafficSecret1 := make([]byte, 16)
// trafficSecret2 := make([]byte, 16)
// rand.Read(trafficSecret1)
// rand.Read(trafficSecret2)
// client := newUpdatableAEAD(&rttStats, nil, utils.DefaultLogger, protocol.Version1)
// server := newUpdatableAEAD(&rttStats, serverTracer, utils.DefaultLogger, protocol.Version1)
// client.SetReadKey(cs, trafficSecret2)
// client.SetWriteKey(cs, trafficSecret1)
// server.SetReadKey(cs, trafficSecret1)
// server.SetWriteKey(cs, trafficSecret2)
// now := monotime.Now()
// encrypted0 := client.Seal(nil, []byte(msg), 0x42, ad)
// decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
// require.NoError(t, err)
// require.Equal(t, msg, decrypted)
// require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
// _ = server.Seal(nil, msg, 0x1, ad)
// client.rollKeys()
// encrypted1 := client.Seal(nil, msg, 0x43, ad)
// serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
// decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
// require.NoError(t, err)
// require.Equal(t, msg, decrypted)
// require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
// }
// TestReorderedPacketAfterKeyUpdate lets the client initiate a key update and
// then delivers a packet that was sealed in the previous key phase: the server
// must still be able to open it while staying in key phase 1.
func TestReorderedPacketAfterKeyUpdate(t *testing.T) {
	client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	now := monotime.Now()
	plaintext, aad := []byte(msg), []byte(ad)

	// two packets in key phase 0; only the first one is delivered for now
	delivered := client.Seal(nil, plaintext, 0x42, aad)
	delayed := client.Seal(nil, plaintext, 0x43, aad)
	_, err := server.Open(nil, delivered, now, 0x42, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)
	_ = server.Seal(nil, plaintext, 0x1, aad)

	// the client updates its keys and sends a packet in key phase 1
	client.rollKeys()
	afterUpdate := client.Seal(nil, plaintext, 0x44, aad)
	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	_, err = server.Open(nil, afterUpdate, now, 0x44, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantEvents := bothSides(qlog.KeyUpdated{Trigger: qlog.KeyUpdateRemote, KeyPhase: 1})
	require.Equal(t, wantEvents, eventRecorder.Events())

	// the delayed key phase 0 packet arrives only now
	opened, err := server.Open(nil, delayed, now, 0x43, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)
	require.Equal(t, msg, string(opened))
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
}
// TestDropsKeys3PTOsAfterKeyUpdate performs a client-initiated key update and
// then delivers a key phase 0 packet just over three PTOs later. The server
// must answer with ErrKeysDropped and record KeyDiscarded events for key
// phase 0.
func TestDropsKeys3PTOsAfterKeyUpdate(t *testing.T) {
	rttStats := utils.NewRTTStats()
	client, server, eventRecorder := setupEndpoints(t, rttStats)
	now := monotime.Now()
	rttStats.UpdateRTT(10*time.Millisecond, 0)
	pto := rttStats.PTO(true)
	plaintext, aad := []byte(msg), []byte(ad)

	// two packets in key phase 0; only the first one is delivered in time
	delivered := client.Seal(nil, plaintext, 0x42, aad)
	late := client.Seal(nil, plaintext, 0x43, aad)
	_, err := server.Open(nil, delivered, now, 0x42, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)
	_ = server.Seal(nil, plaintext, 0x1, aad)

	// the client updates its keys and sends a packet in key phase 1
	client.rollKeys()
	afterUpdate := client.Seal(nil, plaintext, 0x44, aad)
	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	_, err = server.Open(nil, afterUpdate, now, 0x44, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantUpdated := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateRemote})
	require.Equal(t, wantUpdated, eventRecorder.Events())
	eventRecorder.Clear()

	// one nanosecond past three PTOs: the old key is expected to be gone
	tooLate := now.Add(3 * pto).Add(time.Nanosecond)
	_, err = server.Open(nil, late, tooLate, 0x43, protocol.KeyPhaseZero, aad)
	require.Equal(t, ErrKeysDropped, err)
	wantDiscarded := bothSides(qlog.KeyDiscarded{KeyPhase: 0})
	require.Equal(t, wantDiscarded, eventRecorder.Events())
}
// TestAllowsFirstKeyUpdateImmediately has the client roll its keys before
// sending anything. A corrupted key phase 1 packet must fail with
// ErrDecryptionFailed; the intact packet must open, after which the recorder
// holds the KeyUpdated events for a remote-triggered update to key phase 1.
func TestAllowsFirstKeyUpdateImmediately(t *testing.T) {
	client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	aad := []byte(ad)

	client.rollKeys()
	sealed := client.Seal(nil, []byte(msg), 0x1337, aad)

	// dropping the last byte makes decryption fail
	corrupted := sealed[:len(sealed)-1]
	_, err := server.Open(nil, corrupted, monotime.Now(), 0x1337, protocol.KeyPhaseOne, aad)
	require.Equal(t, ErrDecryptionFailed, err)

	// the first successful decryption is what updates the key phase
	_, err = server.Open(nil, sealed, monotime.Now(), 0x1337, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)
	wantEvents := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateRemote})
	require.Equal(t, wantEvents, eventRecorder.Events())
}
// TestRejectFrequentKeyUpdates moves both endpoints to key phase 1, lets the
// server open one packet there, and then has the client roll its keys again
// right away. The server must reject the resulting packet with a
// KeyUpdateError transport error.
func TestRejectFrequentKeyUpdates(t *testing.T) {
	client, server, _ := setupEndpoints(t, utils.NewRTTStats())
	plaintext, aad := []byte(msg), []byte(ad)

	server.rollKeys()
	client.rollKeys()
	phaseOnePacket := client.Seal(nil, plaintext, 0x42, aad)
	_, err := server.Open(nil, phaseOnePacket, monotime.Now(), 0x42, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)

	// a second update by the client; its key phase bit is zero again
	client.rollKeys()
	phaseTwoPacket := client.Seal(nil, plaintext, 0x42, aad)
	_, err = server.Open(nil, phaseTwoPacket, monotime.Now(), 0x42, protocol.KeyPhaseZero, aad)

	wantErr := &qerr.TransportError{
		ErrorCode:    qerr.KeyUpdateError,
		ErrorMessage: "keys updated too quickly",
	}
	require.Equal(t, wantErr, err)
}
// setKeyUpdateIntervals overrides the key update interval (through
// SetKeyUpdateInterval) and the package-level FirstKeyUpdateInterval for the
// duration of the test. Both are restored by cleanups registered on t.
func setKeyUpdateIntervals(t *testing.T, firstKeyUpdateInterval, keyUpdateInterval uint64) {
	// SetKeyUpdateInterval hands back the function that undoes the override.
	t.Cleanup(SetKeyUpdateInterval(keyUpdateInterval))

	previous := FirstKeyUpdateInterval
	FirstKeyUpdateInterval = firstKeyUpdateInterval
	t.Cleanup(func() {
		FirstKeyUpdateInterval = previous
	})
}
// TestInitiateKeyUpdateAfterSendingMaxPackets checks locally initiated key
// updates on the sending side. The first update happens after
// firstKeyUpdateInterval sealed packets without any acknowledgement. The next
// one only happens once a packet sent in key phase 1 has been acknowledged,
// regardless of how many packets were sealed in the meantime.
func TestInitiateKeyUpdateAfterSendingMaxPackets(t *testing.T) {
	const (
		firstKeyUpdateInterval = 5
		keyUpdateInterval      = 20
	)
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval)
	client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	server.SetHandshakeConfirmed()

	plaintext, aad := []byte(msg), []byte(ad)
	var pn protocol.PacketNumber

	// first key update: no acknowledgement is required
	for range firstKeyUpdateInterval {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
		pn++
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantFirst := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantFirst, eventRecorder.Events())
	eventRecorder.Clear()

	// Seal twice the update interval: without an acknowledgement for the
	// current key phase, the server has to stay in key phase 1.
	for range 2 * keyUpdateInterval {
		require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
		pn++
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())

	// receive an ACK for a packet sent in key phase 1
	client.rollKeys()
	ackPacket := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
	_, err := server.Open(nil, ackPacket, monotime.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
	require.NoError(t, err)
	require.NoError(t, server.SetLargestAcked(firstKeyUpdateInterval))
	// NOTE(review): nothing is recorded until KeyPhase() is queried below,
	// which suggests the query itself initiates the update — confirm.
	require.Empty(t, eventRecorder.Events())

	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	wantSecond := bothSides(qlog.KeyDiscarded{KeyPhase: 0})
	wantSecond = append(wantSecond, bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})...)
	require.Equal(t, wantSecond, eventRecorder.Events())
}
// TestKeyUpdateEnforceACKKeyPhase lets the server update its keys locally and
// seal one packet in key phase 1 while it has never opened a packet in that
// phase. Acknowledgements for the key phase 0 packets are accepted; an
// acknowledgement for the key phase 1 packet must fail with a KeyUpdateError.
func TestKeyUpdateEnforceACKKeyPhase(t *testing.T) {
	const firstKeyUpdateInterval = 5
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, protocol.KeyUpdateInterval)
	_, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	server.SetHandshakeConfirmed()

	plaintext, aad := []byte(msg), []byte(ad)

	// Seal enough packets to make the server update its keys.
	for pn := protocol.PacketNumber(0); pn < firstKeyUpdateInterval; pn++ {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantEvents := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantEvents, eventRecorder.Events())
	eventRecorder.Clear()

	// Send one packet using the new keys.
	const nextPN = firstKeyUpdateInterval + 1
	server.Seal(nil, plaintext, nextPN, aad)

	// Acknowledgements for the packets sealed in key phase 0 are fine.
	for pn := protocol.PacketNumber(0); pn < firstKeyUpdateInterval; pn++ {
		require.NoError(t, server.SetLargestAcked(pn))
	}

	// We haven't decrypted any packet in the new key phase yet, so any ACK we
	// received was sent in the old key phase and cannot cover nextPN.
	err := server.SetLargestAcked(nextPN)
	require.Error(t, err)
	var transportErr *qerr.TransportError
	require.ErrorAs(t, err, &transportErr)
	require.Equal(t, qerr.KeyUpdateError, transportErr.ErrorCode)
	require.Equal(t, "received ACK for key phase 1, but peer didn't update keys", transportErr.ErrorMessage)
	require.Empty(t, eventRecorder.Events())
}
// TestKeyUpdateAfterOpeningMaxPackets checks locally initiated key updates on
// the receiving side. The first update happens after firstKeyUpdateInterval
// opened packets. The next one only happens once a packet sealed in key
// phase 1 has been acknowledged.
func TestKeyUpdateAfterOpeningMaxPackets(t *testing.T) {
	const (
		firstKeyUpdateInterval = 5
		keyUpdateInterval      = 20
	)
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval)
	client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	server.SetHandshakeConfirmed()

	payload := []byte("message")
	aad := []byte("additional data")
	var pn protocol.PacketNumber

	// first key update: no acknowledgement is required
	for range firstKeyUpdateInterval {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		sealed := client.Seal(nil, payload, pn, aad)
		_, err := server.Open(nil, sealed, monotime.Now(), pn, protocol.KeyPhaseZero, aad)
		require.NoError(t, err)
		pn++
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantFirst := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantFirst, eventRecorder.Events())
	eventRecorder.Clear()

	// Open a full update interval worth of key phase 1 packets: without an
	// acknowledgement for the current key phase, the server stays in phase 1.
	client.rollKeys()
	for range keyUpdateInterval {
		require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
		sealed := client.Seal(nil, payload, pn, aad)
		_, err := server.Open(nil, sealed, monotime.Now(), pn, protocol.KeyPhaseOne, aad)
		require.NoError(t, err)
		pn++
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())

	// seal a packet in key phase 1 and have it acknowledged
	server.Seal(nil, payload, 1, aad)
	require.NoError(t, server.SetLargestAcked(firstKeyUpdateInterval+1))

	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	wantSecond := bothSides(qlog.KeyDiscarded{KeyPhase: 0})
	wantSecond = append(wantSecond, bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})...)
	require.Equal(t, wantSecond, eventRecorder.Events())
}
// TestKeyUpdateKeyPhaseSkipping updates the server's keys locally while the
// client never sends anything in key phase 1. A key phase 0 packet arriving
// ten PTOs later must still open, and no further events may be recorded.
func TestKeyUpdateKeyPhaseSkipping(t *testing.T) {
	const (
		firstKeyUpdateInterval = 5
		keyUpdateInterval      = 20
	)
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval)

	rttStats := utils.NewRTTStats()
	rttStats.UpdateRTT(10*time.Millisecond, 0)
	client, server, eventRecorder := setupEndpoints(t, rttStats)
	server.SetHandshakeConfirmed()

	now := monotime.Now()
	plaintext, aad := []byte(msg), []byte(ad)

	first := client.Seal(nil, plaintext, 1, aad)
	_, err := server.Open(nil, first, now, 1, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)

	// seal and acknowledge enough packets to trigger the first key update
	for pn := protocol.PacketNumber(0); pn < firstKeyUpdateInterval; pn++ {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
		require.NoError(t, server.SetLargestAcked(pn))
	}
	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantEvents := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantEvents, eventRecorder.Events())
	eventRecorder.Clear()

	// The server never received a packet at key phase 1.
	// The keys for key phase 0 must still be usable at a much later point.
	second := client.Seal(nil, plaintext, 1, aad)
	muchLater := now.Add(10 * rttStats.PTO(true))
	_, err = server.Open(nil, second, muchLater, 1, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)
	require.Empty(t, eventRecorder.Events())
}
// TestFastKeyUpdatesByPeer lets the server update to key phase 1 locally and
// confirm that phase, and then has the client jump ahead to key phase 2 at the
// same point in time. The server must follow and record KeyDiscarded events
// for key phase 0 followed by KeyUpdated events for a remote-triggered update
// to key phase 2.
func TestFastKeyUpdatesByPeer(t *testing.T) {
	const (
		firstKeyUpdateInterval = 5
		keyUpdateInterval      = 20
	)
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval)
	client, server, eventRecorder := setupEndpoints(t, utils.NewRTTStats())
	server.SetHandshakeConfirmed()

	plaintext, aad := []byte(msg), []byte(ad)
	var pn protocol.PacketNumber

	// seal enough packets to trigger the first, locally initiated key update
	for range firstKeyUpdateInterval {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
		pn++
	}
	ackPacket := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
	_, err := server.Open(nil, ackPacket, monotime.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
	require.NoError(t, err)
	require.NoError(t, server.SetLargestAcked(0))

	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantLocal := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantLocal, eventRecorder.Events())
	eventRecorder.Clear()

	// Send and receive an acknowledgement for a packet in key phase 1.
	// We are now running a timer to drop the keys with 3 PTO.
	server.Seal(nil, plaintext, pn, aad)
	client.rollKeys()
	phaseOnePacket := client.Seal(nil, plaintext, 2, aad)
	now := monotime.Now()
	_, err = server.Open(nil, phaseOnePacket, now, 2, protocol.KeyPhaseOne, aad)
	require.NoError(t, err)
	require.NoError(t, server.SetLargestAcked(pn))

	// Now the client sends us a packet in key phase 2, forcing us to update
	// keys before the 3 PTO period is over.
	// This means that we need to drop the keys for key phase 0 immediately.
	client.rollKeys()
	phaseTwoPacket := client.Seal(nil, plaintext, 3, aad)
	_, err = server.Open(nil, phaseTwoPacket, now, 3, protocol.KeyPhaseZero, aad)
	require.NoError(t, err)

	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	wantRemote := bothSides(qlog.KeyDiscarded{KeyPhase: 0})
	wantRemote = append(wantRemote, bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateRemote})...)
	require.Equal(t, wantRemote, eventRecorder.Events())
}
// TestFastKeyUpdateByUs drives two locally initiated key updates on the
// server. After the second one, a key phase 1 packet arriving ten PTOs later
// must still open, and no further events may be recorded.
func TestFastKeyUpdateByUs(t *testing.T) {
	const (
		firstKeyUpdateInterval = 5
		keyUpdateInterval      = 20
	)
	setKeyUpdateIntervals(t, firstKeyUpdateInterval, keyUpdateInterval)

	rttStats := utils.NewRTTStats()
	rttStats.UpdateRTT(10*time.Millisecond, 0)
	client, server, eventRecorder := setupEndpoints(t, rttStats)
	server.SetHandshakeConfirmed()

	plaintext, aad := []byte(msg), []byte(ad)

	// send so many packets that we initiate the first key update
	for pn := protocol.PacketNumber(0); pn < firstKeyUpdateInterval; pn++ {
		require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
	}
	ackPacket := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
	_, err := server.Open(nil, ackPacket, monotime.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
	require.NoError(t, err)
	require.NoError(t, server.SetLargestAcked(0))

	require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
	wantFirst := bothSides(qlog.KeyUpdated{KeyPhase: 1, Trigger: qlog.KeyUpdateLocal})
	require.Equal(t, wantFirst, eventRecorder.Events())
	eventRecorder.Clear()

	// send so many packets that we initiate the next key update
	for pn := protocol.PacketNumber(keyUpdateInterval); pn < 2*keyUpdateInterval; pn++ {
		require.Equal(t, protocol.KeyPhaseOne, server.KeyPhase())
		server.Seal(nil, plaintext, pn, aad)
	}
	client.rollKeys()
	ackPacket = client.Seal(nil, []byte("foobar"), 2, []byte("ad"))
	now := monotime.Now()
	_, err = server.Open(nil, ackPacket, now, 2, protocol.KeyPhaseOne, []byte("ad"))
	require.NoError(t, err)
	require.NoError(t, server.SetLargestAcked(keyUpdateInterval))

	require.Equal(t, protocol.KeyPhaseZero, server.KeyPhase())
	wantSecond := bothSides(qlog.KeyDiscarded{KeyPhase: 0})
	wantSecond = append(wantSecond, bothSides(qlog.KeyUpdated{KeyPhase: 2, Trigger: qlog.KeyUpdateLocal})...)
	require.Equal(t, wantSecond, eventRecorder.Events())
	eventRecorder.Clear()

	// We haven't received an ACK for a packet sent in key phase 2 yet.
	// Make sure we canceled the timer to drop the previous key phase.
	latePacket := client.Seal(nil, []byte("foobar"), 3, []byte("ad"))
	muchLater := now.Add(10 * rttStats.PTO(true))
	_, err = server.Open(nil, latePacket, muchLater, 3, protocol.KeyPhaseOne, []byte("ad"))
	require.NoError(t, err)
	require.Empty(t, eventRecorder.Events())
}
// getClientAndServer returns a client and a server updatableAEAD for QUIC v1
// that share one RTTStats instance, use cipherSuites[0] and have two random
// 16-byte traffic secrets installed so that each side's write key matches the
// other side's read key. Neither endpoint gets an event recorder.
func getClientAndServer() (client, server *updatableAEAD) {
	clientSecret := make([]byte, 16)
	serverSecret := make([]byte, 16)
	rand.Read(clientSecret)
	rand.Read(serverSecret)

	suite := cipherSuites[0]
	sharedRTT := utils.NewRTTStats()
	client = newUpdatableAEAD(sharedRTT, nil, utils.DefaultLogger, protocol.Version1)
	server = newUpdatableAEAD(sharedRTT, nil, utils.DefaultLogger, protocol.Version1)

	client.SetReadKey(suite, serverSecret)
	client.SetWriteKey(suite, clientSecret)
	server.SetReadKey(suite, clientSecret)
	server.SetWriteKey(suite, serverSecret)

	return client, server
}
// BenchmarkPacketEncryption measures sealing a 1200-byte payload with 32 bytes
// of associated data, using an increasing packet number. The output is written
// back into the input buffer.
func BenchmarkPacketEncryption(b *testing.B) {
	const payloadLen = 1200

	client, _ := getClientAndServer()
	buf := make([]byte, payloadLen)
	rand.Read(buf)
	aad := make([]byte, 32)
	rand.Read(aad)

	var pn protocol.PacketNumber
	for b.Loop() {
		// Seal appends to buf[:0] while reading the first payloadLen bytes of buf.
		buf = client.Seal(buf[:0], buf[:payloadLen], pn, aad)
		pn++
	}
}
// BenchmarkPacketDecryption measures opening one packet, sealed once up front
// from a 1200-byte payload with 32 bytes of associated data, over and over
// into a reused destination buffer. The benchmark aborts on the first failure.
func BenchmarkPacketDecryption(b *testing.B) {
	const payloadLen = 1200

	client, server := getClientAndServer()
	sealed := make([]byte, payloadLen)
	out := make([]byte, payloadLen)
	rand.Read(sealed)
	aad := make([]byte, 32)
	rand.Read(aad)
	sealed = client.Seal(sealed[:0], sealed[:payloadLen], 1337, aad)

	for b.Loop() {
		_, err := server.Open(out[:0], sealed, 0, 1337, protocol.KeyPhaseZero, aad)
		if err != nil {
			b.Fatalf("opening failed: %v", err)
		}
	}
}
// BenchmarkRollKeys measures rollKeys and afterwards checks that the client's
// key phase counter equals the number of benchmark iterations.
func BenchmarkRollKeys(b *testing.B) {
	client, _ := getClientAndServer()

	for b.Loop() {
		client.rollKeys()
	}

	if rolled := int(client.keyPhase); rolled != b.N {
		b.Fatal("didn't roll keys often enough")
	}
}