From 195bdc9944215bf51ebf6b9d9c955991f94c4ced Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 29 May 2019 13:07:05 +0100 Subject: [PATCH 1/4] remove unused handshakeComplete member variable from sent packet handler --- internal/ackhandler/sent_packet_handler.go | 4 +--- internal/ackhandler/sent_packet_handler_test.go | 4 ---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7d997ead..4c002b8e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -59,8 +59,7 @@ type sentPacketHandler struct { congestion congestion.SendAlgorithmWithDebugInfos rttStats *congestion.RTTStats - handshakeComplete bool - maxAckDelay time.Duration + maxAckDelay time.Duration // The number of times the crypto packets have been retransmitted without receiving an ack. cryptoCount uint32 @@ -125,7 +124,6 @@ func (h *sentPacketHandler) SetHandshakeComplete() { } } h.retransmissionQueue = queue - h.handshakeComplete = true } func (h *sentPacketHandler) SetMaxAckDelay(mad time.Duration) { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index efc20890..e5b1a66e 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -808,10 +808,6 @@ var _ = Describe("SentPacketHandler", func() { }) Context("crypto packets", func() { - BeforeEach(func() { - handler.handshakeComplete = false - }) - It("detects the crypto timeout", func() { now := time.Now() sendTime := now.Add(-time.Minute) From 4d5b4fd790608bd9d748c8ae6781411e2209d332 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 30 May 2019 01:01:27 +0800 Subject: [PATCH 2/4] add a function to drop sent packets of a certain encryption level --- internal/ackhandler/interfaces.go | 2 +- internal/ackhandler/sent_packet_handler.go | 32 ++++++++-------- .../ackhandler/sent_packet_handler_test.go | 38 ++++++++++++++----- .../mocks/ackhandler/sent_packet_handler.go | 24 ++++++------ session.go | 6 ++- 5 files changed, 62 insertions(+), 40 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index d3550feb..4f667ac3 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -14,7 +14,7 @@ type SentPacketHandler interface { SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error SetMaxAckDelay(time.Duration) - SetHandshakeComplete() + DropPackets(protocol.EncryptionLevel) ResetForRetry() error // The SendMode determines if and what kind of packets can be sent. diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 4c002b8e..acfdf4ff 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -102,27 +102,27 @@ func NewSentPacketHandler( } } -func (h *sentPacketHandler) SetHandshakeComplete() { - h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.") +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + // remove outstanding packets from bytes_in_flight + pnSpace := h.getPacketNumberSpace(encLevel) + var packets []*Packet + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + packets = append(packets, p) + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + return true, nil + }) + for _, p := range packets { + pnSpace.history.Remove(p.PacketNumber) + } + // remove packets from the retransmission queue var queue []*Packet for _, packet := range h.retransmissionQueue { - if packet.EncryptionLevel == protocol.Encryption1RTT { + if packet.EncryptionLevel != encLevel { queue = append(queue, packet) } } - for _, pnSpace := range []*packetNumberSpace{h.initialPackets, h.handshakePackets} { - var cryptoPackets []*Packet - pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - cryptoPackets = append(cryptoPackets, p) - return true, nil - }) - for _, p := range cryptoPackets { - pnSpace.history.Remove(p.PacketNumber) - } - } h.retransmissionQueue = queue } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index e5b1a66e..ccb4ba6b 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -49,7 +49,6 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler) - handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -847,24 +846,45 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) }) - It("deletes crypto packets when the handshake completes", func() { + It("deletes Initial packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial}) handler.SentPacket(p) } - for i := protocol.PacketNumber(0); i <= 6; i++ { + for i := protocol.PacketNumber(0); i < 10; i++ { p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) handler.SentPacket(p) } - Expect(handler.bytesInFlight).ToNot(BeZero()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial)) - handler.queuePacketForRetransmission(getPacket(3, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionHandshake)) - handler.SetHandshakeComplete() + lostPacket := getPacket(3, protocol.EncryptionHandshake) + handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.initialPackets.history.Len()).To(BeZero()) - Expect(handler.handshakePackets.history.Len()).To(BeZero()) - Expect(handler.bytesInFlight).To(BeZero()) + Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) packet := handler.DequeuePacketForRetransmission() - Expect(packet).To(BeNil()) + Expect(packet).To(Equal(lostPacket)) + }) + + It("deletes Handshake packets", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) + handler.SentPacket(p) + } + for i := protocol.PacketNumber(0); i < 10; i++ { + p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT}) + handler.SentPacket(p) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) + handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial)) + lostPacket := getPacket(3, protocol.Encryption1RTT) + handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) + handler.DropPackets(protocol.EncryptionHandshake) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + Expect(handler.handshakePackets.history.Len()).To(BeZero()) + packet := handler.DequeuePacketForRetransmission() + Expect(packet).To(Equal(lostPacket)) }) }) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index a712c816..fab8ff08 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -66,6 +66,18 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket)) } +// DropPackets mocks base method +func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) +} + // GetAlarmTimeout mocks base method func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time { m.ctrl.T.Helper() @@ -203,18 +215,6 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1) } -// SetHandshakeComplete mocks base method -func (m *MockSentPacketHandler) SetHandshakeComplete() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeComplete") -} - -// SetHandshakeComplete indicates an expected call of SetHandshakeComplete -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) -} - // SetMaxAckDelay mocks base method func (m *MockSentPacketHandler) SetMaxAckDelay(arg0 time.Duration) { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 4b37ce77..0728ede3 100644 --- a/session.go +++ b/session.go @@ -485,7 +485,8 @@ func (s *session) handleHandshakeComplete() { // independent from the application protocol. if s.perspective == protocol.PerspectiveServer { s.queueControlFrame(&wire.PingFrame{}) - s.sentPacketHandler.SetHandshakeComplete() + s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) + s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) } } @@ -646,7 +647,8 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time if s.perspective == protocol.PerspectiveClient { if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT { s.receivedFirstForwardSecurePacket = true - s.sentPacketHandler.SetHandshakeComplete() + s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) + s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) } } From 4834962cbdee796a11762937dd0cacbde9075bcb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 30 May 2019 13:46:45 +0800 Subject: [PATCH 3/4] add a function to drop received packets of a certain encryption level --- internal/ackhandler/interfaces.go | 1 + .../ackhandler/received_packet_handler.go | 28 ++++++++++++++++--- .../received_packet_handler_test.go | 20 +++++++++++++ .../ackhandler/received_packet_handler.go | 12 ++++++++ 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 4f667ac3..b1abad0a 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -45,6 +45,7 @@ type SentPacketHandler interface { type ReceivedPacketHandler interface { ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error IgnoreBelow(protocol.PacketNumber) + DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index b457fb4b..1035bd12 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -75,9 +75,25 @@ func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) { h.oneRTTPackets.IgnoreBelow(pn) } +func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } +} + func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - initialAlarm := h.initialPackets.GetAlarmTimeout() - handshakeAlarm := h.handshakePackets.GetAlarmTimeout() + var initialAlarm, handshakeAlarm time.Time + if h.initialPackets != nil { + initialAlarm = h.initialPackets.GetAlarmTimeout() + } + if h.handshakePackets != nil { + handshakeAlarm = h.handshakePackets.GetAlarmTimeout() + } oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout() return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) } @@ -86,9 +102,13 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) * var ack *wire.AckFrame switch encLevel { case protocol.EncryptionInitial: - ack = h.initialPackets.GetAckFrame() + if h.initialPackets != nil { + ack = h.initialPackets.GetAckFrame() + } case protocol.EncryptionHandshake: - ack = h.handshakePackets.GetAckFrame() + if h.handshakePackets != nil { + ack = h.handshakePackets.GetAckFrame() + } case protocol.Encryption1RTT: return h.oneRTTPackets.GetAckFrame() default: diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 80a26b7f..7bcea8e3 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -47,4 +47,24 @@ var _ = Describe("Received Packet Handler", func() { Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) }) + + It("drops Initial packets", func() { + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + }) + + It("drops Handshake packets", func() { + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil()) + }) }) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 4c395f27..640b725a 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -36,6 +36,18 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor return m.recorder } +// DropPackets mocks base method +func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) +} + // GetAckFrame mocks base method func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { m.ctrl.T.Helper() From a4989c3d9ced7d3565d823032f50932d5a9d9984 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 30 May 2019 02:23:07 +0800 Subject: [PATCH 4/4] drop Initial and Handshake keys when receiving the first 1-RTT ACK --- internal/ackhandler/sent_packet_handler.go | 32 +++++++++---- .../ackhandler/sent_packet_handler_test.go | 4 +- internal/handshake/crypto_setup.go | 48 +++++++++++++++++-- internal/handshake/crypto_setup_test.go | 9 ++++ internal/handshake/interface.go | 1 + internal/mocks/crypto_setup.go | 12 +++++ session.go | 39 ++++++++------- session_test.go | 1 + 8 files changed, 111 insertions(+), 35 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index acfdf4ff..3570ff98 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -105,17 +105,12 @@ func NewSentPacketHandler( func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { // remove outstanding packets from bytes_in_flight pnSpace := h.getPacketNumberSpace(encLevel) - var packets []*Packet pnSpace.history.Iterate(func(p *Packet) (bool, error) { - packets = append(packets, p) if p.includedInBytesInFlight { h.bytesInFlight -= p.Length } return true, nil }) - for _, p := range packets { - pnSpace.history.Remove(p.PacketNumber) - } // remove packets from the retransmission queue var queue []*Packet for _, packet := range h.retransmissionQueue { @@ -124,6 +119,15 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } } h.retransmissionQueue = queue + // drop the packet history + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } } func (h *sentPacketHandler) SetMaxAckDelay(mad time.Duration) { @@ -312,7 +316,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( } func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { - return h.initialPackets.history.HasOutstandingPackets() || h.handshakePackets.history.HasOutstandingPackets() + var hasInitial, hasHandshake bool + if h.initialPackets != nil { + hasInitial = h.initialPackets.history.HasOutstandingPackets() + } + if h.handshakePackets != nil { + hasHandshake = h.handshakePackets.history.HasOutstandingPackets() + } + return hasInitial || hasHandshake } func (h *sentPacketHandler) hasOutstandingPackets() bool { @@ -536,8 +547,13 @@ func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) p } func (h *sentPacketHandler) SendMode() SendMode { - numTrackedPackets := len(h.retransmissionQueue) + h.initialPackets.history.Len() + - h.handshakePackets.history.Len() + h.oneRTTPackets.history.Len() + numTrackedPackets := len(h.retransmissionQueue) + h.oneRTTPackets.history.Len() + if h.initialPackets != nil { + numTrackedPackets += h.initialPackets.history.Len() + } + if h.handshakePackets != nil { + numTrackedPackets += h.handshakePackets.history.Len() + } // Don't send any packets if we're keeping track of the maximum number of packets. // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index ccb4ba6b..948c8b4b 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -861,7 +861,7 @@ var _ = Describe("SentPacketHandler", func() { handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.DropPackets(protocol.EncryptionInitial) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - Expect(handler.initialPackets.history.Len()).To(BeZero()) + Expect(handler.initialPackets).To(BeNil()) Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) packet := handler.DequeuePacketForRetransmission() Expect(packet).To(Equal(lostPacket)) @@ -882,7 +882,7 @@ var _ = Describe("SentPacketHandler", func() { handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.DropPackets(protocol.EncryptionHandshake) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) - Expect(handler.handshakePackets.history.Len()).To(BeZero()) + Expect(handler.handshakePackets).To(BeNil()) packet := handler.DequeuePacketForRetransmission() Expect(packet).To(Equal(lostPacket)) }) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index de6ddc68..e7cdea4f 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -53,10 +53,15 @@ func (m messageType) String() string { } } -// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level, -// but the corresponding opener has not yet been initialized -// This can happen when packets arrive out of order. -var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available") +var ( + // ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level, + // but the corresponding opener has not yet been initialized + // This can happen when packets arrive out of order. + ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available") + // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding keys have already been dropped. + ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") +) type cryptoSetup struct { tlsConf *qtls.Config @@ -67,6 +72,8 @@ type cryptoSetup struct { paramsChan <-chan []byte handleParamsCallback func([]byte) + dropKeyCallback func(protocol.EncryptionLevel) + alertChan chan uint8 // HandleData() sends errors on the messageErrChan messageErrChan chan error @@ -121,6 +128,7 @@ func NewCryptoSetupClient( remoteAddr net.Addr, tp *TransportParameters, handleParams func([]byte), + dropKeys func(protocol.EncryptionLevel), tlsConf *tls.Config, logger utils.Logger, ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { @@ -131,6 +139,7 @@ func NewCryptoSetupClient( connID, tp, handleParams, + dropKeys, tlsConf, logger, protocol.PerspectiveClient, @@ -151,6 +160,7 @@ func NewCryptoSetupServer( remoteAddr net.Addr, tp *TransportParameters, handleParams func([]byte), + dropKeys func(protocol.EncryptionLevel), tlsConf *tls.Config, logger utils.Logger, ) (CryptoSetup, error) { @@ -161,6 +171,7 @@ func NewCryptoSetupServer( connID, tp, handleParams, + dropKeys, tlsConf, logger, protocol.PerspectiveServer, @@ -179,6 +190,7 @@ func newCryptoSetup( connID protocol.ConnectionID, tp *TransportParameters, handleParams func([]byte), + dropKeys func(protocol.EncryptionLevel), tlsConf *tls.Config, logger utils.Logger, perspective protocol.Perspective, @@ -197,6 +209,7 @@ func newCryptoSetup( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, handleParamsCallback: handleParams, + dropKeyCallback: dropKeys, paramsChan: extHandler.TransportParameters(), logger: logger, perspective: perspective, @@ -225,6 +238,24 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error { return nil } +func (h *cryptoSetup) Received1RTTAck() { + // drop initial keys + // TODO: do this earlier + if h.initialOpener != nil { + h.initialOpener = nil + h.initialSealer = nil + h.dropKeyCallback(protocol.EncryptionInitial) + h.logger.Debugf("Dropping Initial keys.") + } + // drop handshake keys + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + h.logger.Debugf("Dropping Handshake keys.") + h.dropKeyCallback(protocol.EncryptionHandshake) + } +} + func (h *cryptoSetup) RunHandshake() error { // Handle errors that might occur when HandleData() is called. handshakeComplete := make(chan struct{}) @@ -554,10 +585,17 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) switch level { case protocol.EncryptionInitial: + if h.initialOpener == nil { + return nil, ErrKeysDropped + } return h.initialOpener, nil case protocol.EncryptionHandshake: if h.handshakeOpener == nil { - return nil, ErrOpenerNotYetAvailable + if h.initialOpener != nil { + return nil, ErrOpenerNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped } return h.handshakeOpener, nil case protocol.Encryption1RTT: diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 63d5eb1b..8743a631 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -87,6 +87,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, tlsConf, utils.DefaultLogger.WithPrefix("server"), ) @@ -115,6 +116,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -149,6 +151,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -177,6 +180,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) @@ -256,6 +260,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, clientConf, utils.DefaultLogger.WithPrefix("client"), ) @@ -271,6 +276,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{StatelessResetToken: &token}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, serverConf, utils.DefaultLogger.WithPrefix("server"), ) @@ -313,6 +319,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, &TransportParameters{}, func([]byte) {}, + func(protocol.EncryptionLevel) {}, &tls.Config{InsecureSkipVerify: true}, utils.DefaultLogger.WithPrefix("client"), ) @@ -350,6 +357,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, cTransportParameters, func(p []byte) { sTransportParametersRcvd = p }, + func(protocol.EncryptionLevel) {}, clientConf, utils.DefaultLogger.WithPrefix("client"), ) @@ -369,6 +377,7 @@ var _ = Describe("Crypto Setup TLS", func() { nil, sTransportParameters, func(p []byte) { cTransportParametersRcvd = p }, + func(protocol.EncryptionLevel) {}, testdata.GetTLSConfig(), utils.DefaultLogger.WithPrefix("server"), ) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 22522513..dd893e55 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -36,6 +36,7 @@ type CryptoSetup interface { ChangeConnectionID(protocol.ConnectionID) error HandleMessage([]byte, protocol.EncryptionLevel) bool + Received1RTTAck() ConnectionState() tls.ConnectionState GetSealer() (protocol.EncryptionLevel, Sealer) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 01e9bea7..4983c371 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -137,6 +137,18 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } +// Received1RTTAck mocks base method +func (m *MockCryptoSetup) Received1RTTAck() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Received1RTTAck") +} + +// Received1RTTAck indicates an expected call of Received1RTTAck +func (mr *MockCryptoSetupMockRecorder) Received1RTTAck() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Received1RTTAck", reflect.TypeOf((*MockCryptoSetup)(nil).Received1RTTAck)) +} + // RunHandshake mocks base method func (m *MockCryptoSetup) RunHandshake() error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 0728ede3..44bea138 100644 --- a/session.go +++ b/session.go @@ -49,6 +49,7 @@ type streamManager interface { type cryptoStreamHandler interface { RunHandshake() error ChangeConnectionID(protocol.ConnectionID) error + Received1RTTAck() io.Closer ConnectionState() tls.ConnectionState } @@ -129,9 +130,8 @@ type session struct { handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool - receivedRetry bool - receivedFirstPacket bool - receivedFirstForwardSecurePacket bool + receivedRetry bool + receivedFirstPacket bool sessionCreationTime time.Time // The idle timeout is set based on the max of the time we received the last packet... @@ -199,6 +199,7 @@ var newSession = func( conn.RemoteAddr(), params, s.processTransportParameters, + s.dropEncryptionLevel, tlsConf, logger, ) @@ -267,6 +268,7 @@ var newClientSession = func( conn.RemoteAddr(), params, s.processTransportParameters, + s.dropEncryptionLevel, tlsConf, logger, ) @@ -485,8 +487,6 @@ func (s *session) handleHandshakeComplete() { // independent from the application protocol. if s.perspective == protocol.PerspectiveServer { s.queueControlFrame(&wire.PingFrame{}) - s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) - s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) } } @@ -560,16 +560,19 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / packet, err := s.unpacker.Unpack(hdr, p.data) if err != nil { - if err == handshake.ErrOpenerNotYetAvailable { + switch err { + case handshake.ErrKeysDropped: + s.logger.Debugf("Dropping packet because we already dropped the keys.") + case handshake.ErrOpenerNotYetAvailable: // Sealer for this encryption level not yet available. // Try again later. wasQueued = true s.tryQueueingUndecryptablePacket(p) - return false + default: + // This might be a packet injected by an attacker. + // Drop it. + s.logger.Debugf("Dropping packet that could not be unpacked. Unpack error: %s", err) } - // This might be a packet injected by an attacker. - // Drop it. - s.logger.Debugf("Dropping packet that could not be unpacked. Unpack error: %s", err) return false } @@ -642,16 +645,6 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false - // The client completes the handshake first (after sending the CFIN). - // We know that the server completed the handshake as soon as we receive a forward-secure packet. - if s.perspective == protocol.PerspectiveClient { - if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT { - s.receivedFirstForwardSecurePacket = true - s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) - s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) - } - } - r := bytes.NewReader(packet.data) var isAckEliciting bool for { @@ -834,6 +827,7 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, } if encLevel == protocol.Encryption1RTT { s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) + s.cryptoStreamHandler.Received1RTTAck() } return nil } @@ -924,6 +918,11 @@ func (s *session) handleCloseError(closeErr closeError) { } } +func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { + s.sentPacketHandler.DropPackets(encLevel) + s.receivedPacketHandler.DropPackets(encLevel) +} + func (s *session) processTransportParameters(data []byte) { var params *handshake.TransportParameters var err error diff --git a/session_test.go b/session_test.go index 1a72f2f4..afe51298 100644 --- a/session_test.go +++ b/session_test.go @@ -161,6 +161,7 @@ var _ = Describe("Session", func() { }) It("tells the ReceivedPacketHandler to ignore low ranges", func() { + cryptoSetup.EXPECT().Received1RTTAck() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())