Merge pull request #1929 from lucas-clemente/drop-handshake-keys

drop Initial and Handshake keys when receiving the first 1-RTT ACK
This commit is contained in:
Marten Seemann
2019-05-30 20:52:39 +08:00
committed by GitHub
13 changed files with 221 additions and 76 deletions

View File

@@ -14,7 +14,7 @@ type SentPacketHandler interface {
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetMaxAckDelay(time.Duration)
SetHandshakeComplete()
DropPackets(protocol.EncryptionLevel)
ResetForRetry() error
// The SendMode determines if and what kind of packets can be sent.
@@ -45,6 +45,7 @@ type SentPacketHandler interface {
type ReceivedPacketHandler interface {
ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time
GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame

View File

@@ -75,9 +75,25 @@ func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) {
h.oneRTTPackets.IgnoreBelow(pn)
}
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
initialAlarm := h.initialPackets.GetAlarmTimeout()
handshakeAlarm := h.handshakePackets.GetAlarmTimeout()
var initialAlarm, handshakeAlarm time.Time
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
}
@@ -86,9 +102,13 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *
var ack *wire.AckFrame
switch encLevel {
case protocol.EncryptionInitial:
ack = h.initialPackets.GetAckFrame()
if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame()
}
case protocol.EncryptionHandshake:
ack = h.handshakePackets.GetAckFrame()
if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame()
}
case protocol.Encryption1RTT:
return h.oneRTTPackets.GetAckFrame()
default:

View File

@@ -47,4 +47,24 @@ var _ = Describe("Received Packet Handler", func() {
Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond))
})
It("drops Initial packets", func() {
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil())
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil())
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil())
})
It("drops Handshake packets", func() {
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil())
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil())
Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil())
})
})

View File

@@ -59,8 +59,7 @@ type sentPacketHandler struct {
congestion congestion.SendAlgorithmWithDebugInfos
rttStats *congestion.RTTStats
handshakeComplete bool
maxAckDelay time.Duration
maxAckDelay time.Duration
// The number of times the crypto packets have been retransmitted without receiving an ack.
cryptoCount uint32
@@ -103,29 +102,32 @@ func NewSentPacketHandler(
}
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.")
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
return true, nil
})
// remove packets from the retransmission queue
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.Encryption1RTT {
if packet.EncryptionLevel != encLevel {
queue = append(queue, packet)
}
}
for _, pnSpace := range []*packetNumberSpace{h.initialPackets, h.handshakePackets} {
var cryptoPackets []*Packet
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
cryptoPackets = append(cryptoPackets, p)
return true, nil
})
for _, p := range cryptoPackets {
pnSpace.history.Remove(p.PacketNumber)
}
}
h.retransmissionQueue = queue
h.handshakeComplete = true
// drop the packet history
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *sentPacketHandler) SetMaxAckDelay(mad time.Duration) {
@@ -314,7 +316,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
return h.initialPackets.history.HasOutstandingPackets() || h.handshakePackets.history.HasOutstandingPackets()
var hasInitial, hasHandshake bool
if h.initialPackets != nil {
hasInitial = h.initialPackets.history.HasOutstandingPackets()
}
if h.handshakePackets != nil {
hasHandshake = h.handshakePackets.history.HasOutstandingPackets()
}
return hasInitial || hasHandshake
}
func (h *sentPacketHandler) hasOutstandingPackets() bool {
@@ -538,8 +547,13 @@ func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) p
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.initialPackets.history.Len() +
h.handshakePackets.history.Len() + h.oneRTTPackets.history.Len()
numTrackedPackets := len(h.retransmissionQueue) + h.oneRTTPackets.history.Len()
if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
if h.handshakePackets != nil {
numTrackedPackets += h.handshakePackets.history.Len()
}
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,

View File

@@ -49,7 +49,6 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() {
rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler)
handler.SetHandshakeComplete()
streamFrame = wire.StreamFrame{
StreamID: 5,
Data: []byte{0x13, 0x37},
@@ -808,10 +807,6 @@ var _ = Describe("SentPacketHandler", func() {
})
Context("crypto packets", func() {
BeforeEach(func() {
handler.handshakeComplete = false
})
It("detects the crypto timeout", func() {
now := time.Now()
sendTime := now.Add(-time.Minute)
@@ -851,24 +846,45 @@ var _ = Describe("SentPacketHandler", func() {
Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet"))
})
It("deletes crypto packets when the handshake completes", func() {
It("deletes Initial packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial})
handler.SentPacket(p)
}
for i := protocol.PacketNumber(0); i <= 6; i++ {
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
}
Expect(handler.bytesInFlight).ToNot(BeZero())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial))
handler.queuePacketForRetransmission(getPacket(3, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.SetHandshakeComplete()
Expect(handler.initialPackets.history.Len()).To(BeZero())
Expect(handler.handshakePackets.history.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
lostPacket := getPacket(3, protocol.EncryptionHandshake)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.initialPackets).To(BeNil())
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(BeNil())
Expect(packet).To(Equal(lostPacket))
})
It("deletes Handshake packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
}
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT})
handler.SentPacket(p)
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial))
lostPacket := getPacket(3, protocol.Encryption1RTT)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionHandshake)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.handshakePackets).To(BeNil())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket))
})
})

View File

@@ -53,10 +53,15 @@ func (m messageType) String() string {
}
}
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
var (
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have already been dropped.
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
)
type cryptoSetup struct {
tlsConf *qtls.Config
@@ -67,6 +72,8 @@ type cryptoSetup struct {
paramsChan <-chan []byte
handleParamsCallback func([]byte)
dropKeyCallback func(protocol.EncryptionLevel)
alertChan chan uint8
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
@@ -121,6 +128,7 @@ func NewCryptoSetupClient(
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
@@ -131,6 +139,7 @@ func NewCryptoSetupClient(
connID,
tp,
handleParams,
dropKeys,
tlsConf,
logger,
protocol.PerspectiveClient,
@@ -151,6 +160,7 @@ func NewCryptoSetupServer(
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, error) {
@@ -161,6 +171,7 @@ func NewCryptoSetupServer(
connID,
tp,
handleParams,
dropKeys,
tlsConf,
logger,
protocol.PerspectiveServer,
@@ -179,6 +190,7 @@ func newCryptoSetup(
connID protocol.ConnectionID,
tp *TransportParameters,
handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config,
logger utils.Logger,
perspective protocol.Perspective,
@@ -197,6 +209,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
dropKeyCallback: dropKeys,
paramsChan: extHandler.TransportParameters(),
logger: logger,
perspective: perspective,
@@ -225,6 +238,24 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error {
return nil
}
func (h *cryptoSetup) Received1RTTAck() {
// drop initial keys
// TODO: do this earlier
if h.initialOpener != nil {
h.initialOpener = nil
h.initialSealer = nil
h.dropKeyCallback(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
}
// drop handshake keys
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
h.logger.Debugf("Dropping Handshake keys.")
h.dropKeyCallback(protocol.EncryptionHandshake)
}
}
func (h *cryptoSetup) RunHandshake() error {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
@@ -554,10 +585,17 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error)
switch level {
case protocol.EncryptionInitial:
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil
case protocol.EncryptionHandshake:
if h.handshakeOpener == nil {
return nil, ErrOpenerNotYetAvailable
if h.initialOpener != nil {
return nil, ErrOpenerNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
case protocol.Encryption1RTT:

View File

@@ -87,6 +87,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
tlsConf,
utils.DefaultLogger.WithPrefix("server"),
)
@@ -115,6 +116,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@@ -149,6 +151,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@@ -177,6 +180,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)
@@ -256,6 +260,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
@@ -271,6 +276,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{StatelessResetToken: &token},
func([]byte) {},
func(protocol.EncryptionLevel) {},
serverConf,
utils.DefaultLogger.WithPrefix("server"),
)
@@ -313,6 +319,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
&TransportParameters{},
func([]byte) {},
func(protocol.EncryptionLevel) {},
&tls.Config{InsecureSkipVerify: true},
utils.DefaultLogger.WithPrefix("client"),
)
@@ -350,6 +357,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
cTransportParameters,
func(p []byte) { sTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
clientConf,
utils.DefaultLogger.WithPrefix("client"),
)
@@ -369,6 +377,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil,
sTransportParameters,
func(p []byte) { cTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"),
)

View File

@@ -36,6 +36,7 @@ type CryptoSetup interface {
ChangeConnectionID(protocol.ConnectionID) error
HandleMessage([]byte, protocol.EncryptionLevel) bool
Received1RTTAck()
ConnectionState() tls.ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)

View File

@@ -36,6 +36,18 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor
return m.recorder
}
// DropPackets mocks base method
func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropPackets", arg0)
}
// DropPackets indicates an expected call of DropPackets
func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0)
}
// GetAckFrame mocks base method
func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame {
m.ctrl.T.Helper()

View File

@@ -66,6 +66,18 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
}
// DropPackets mocks base method
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropPackets", arg0)
}
// DropPackets indicates an expected call of DropPackets
func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0)
}
// GetAlarmTimeout mocks base method
func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
m.ctrl.T.Helper()
@@ -203,18 +215,6 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1)
}
// SetHandshakeComplete mocks base method
func (m *MockSentPacketHandler) SetHandshakeComplete() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHandshakeComplete")
}
// SetHandshakeComplete indicates an expected call of SetHandshakeComplete
func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete))
}
// SetMaxAckDelay mocks base method
func (m *MockSentPacketHandler) SetMaxAckDelay(arg0 time.Duration) {
m.ctrl.T.Helper()

View File

@@ -137,6 +137,18 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// Received1RTTAck mocks base method
func (m *MockCryptoSetup) Received1RTTAck() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Received1RTTAck")
}
// Received1RTTAck indicates an expected call of Received1RTTAck
func (mr *MockCryptoSetupMockRecorder) Received1RTTAck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Received1RTTAck", reflect.TypeOf((*MockCryptoSetup)(nil).Received1RTTAck))
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
m.ctrl.T.Helper()