From 4d5b4fd790608bd9d748c8ae6781411e2209d332 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 30 May 2019 01:01:27 +0800 Subject: [PATCH] add a function to drop sent packets of a certain encryption level --- internal/ackhandler/interfaces.go | 2 +- internal/ackhandler/sent_packet_handler.go | 32 ++++++++-------- .../ackhandler/sent_packet_handler_test.go | 38 ++++++++++++++----- .../mocks/ackhandler/sent_packet_handler.go | 24 ++++++------ session.go | 6 ++- 5 files changed, 62 insertions(+), 40 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index d3550feb..4f667ac3 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -14,7 +14,7 @@ type SentPacketHandler interface { SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error SetMaxAckDelay(time.Duration) - SetHandshakeComplete() + DropPackets(protocol.EncryptionLevel) ResetForRetry() error // The SendMode determines if and what kind of packets can be sent. diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 4c002b8e..acfdf4ff 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -102,27 +102,27 @@ func NewSentPacketHandler( } } -func (h *sentPacketHandler) SetHandshakeComplete() { - h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.") +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + // remove outstanding packets from bytes_in_flight + pnSpace := h.getPacketNumberSpace(encLevel) + var packets []*Packet + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + packets = append(packets, p) + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + return true, nil + }) + for _, p := range packets { + pnSpace.history.Remove(p.PacketNumber) + } + // remove packets from the retransmission queue var queue []*Packet for _, packet := range h.retransmissionQueue { - if packet.EncryptionLevel == protocol.Encryption1RTT { + if packet.EncryptionLevel != encLevel { queue = append(queue, packet) } } - for _, pnSpace := range []*packetNumberSpace{h.initialPackets, h.handshakePackets} { - var cryptoPackets []*Packet - pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - cryptoPackets = append(cryptoPackets, p) - return true, nil - }) - for _, p := range cryptoPackets { - pnSpace.history.Remove(p.PacketNumber) - } - } h.retransmissionQueue = queue } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index e5b1a66e..ccb4ba6b 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -49,7 +49,6 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler) - handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -847,24 +846,45 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) }) - It("deletes crypto packets when the handshake completes", func() { + It("deletes Initial packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial}) handler.SentPacket(p) } - for i := protocol.PacketNumber(0); i <= 6; i++ { + for i := protocol.PacketNumber(0); i < 10; i++ { p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) handler.SentPacket(p) } - Expect(handler.bytesInFlight).ToNot(BeZero()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial)) - handler.queuePacketForRetransmission(getPacket(3, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionHandshake)) - handler.SetHandshakeComplete() + lostPacket := getPacket(3, protocol.EncryptionHandshake) + handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.initialPackets.history.Len()).To(BeZero()) - Expect(handler.handshakePackets.history.Len()).To(BeZero()) - Expect(handler.bytesInFlight).To(BeZero()) + Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) packet := handler.DequeuePacketForRetransmission() - Expect(packet).To(BeNil()) + Expect(packet).To(Equal(lostPacket)) + }) + + It("deletes Handshake packets", func() { + for i := protocol.PacketNumber(0); i < 6; i++ { + p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) + handler.SentPacket(p) + } + for i := protocol.PacketNumber(0); i < 10; i++ { + p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT}) + handler.SentPacket(p) + } + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) + handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial)) + lostPacket := getPacket(3, protocol.Encryption1RTT) + handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) + handler.DropPackets(protocol.EncryptionHandshake) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) + Expect(handler.handshakePackets.history.Len()).To(BeZero()) + packet := handler.DequeuePacketForRetransmission() + Expect(packet).To(Equal(lostPacket)) }) }) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index a712c816..fab8ff08 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -66,6 +66,18 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket)) } +// DropPackets mocks base method +func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) +} + // GetAlarmTimeout mocks base method func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time { m.ctrl.T.Helper() @@ -203,18 +215,6 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1) } -// SetHandshakeComplete mocks base method -func (m *MockSentPacketHandler) SetHandshakeComplete() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeComplete") -} - -// SetHandshakeComplete indicates an expected call of SetHandshakeComplete -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) -} - // SetMaxAckDelay mocks base method func (m *MockSentPacketHandler) SetMaxAckDelay(arg0 time.Duration) { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 4b37ce77..0728ede3 100644 --- a/session.go +++ b/session.go @@ -485,7 +485,8 @@ func (s *session) handleHandshakeComplete() { // independent from the application protocol. if s.perspective == protocol.PerspectiveServer { s.queueControlFrame(&wire.PingFrame{}) - s.sentPacketHandler.SetHandshakeComplete() + s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) + s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) } } @@ -646,7 +647,8 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time if s.perspective == protocol.PerspectiveClient { if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT { s.receivedFirstForwardSecurePacket = true - s.sentPacketHandler.SetHandshakeComplete() + s.sentPacketHandler.DropPackets(protocol.EncryptionInitial) + s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake) } }