diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 4f667ac3f..b1abad0ab 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -45,6 +45,7 @@ type SentPacketHandler interface { type ReceivedPacketHandler interface { ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error IgnoreBelow(protocol.PacketNumber) + DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index b457fb4b0..1035bd12f 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -75,9 +75,25 @@ func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) { h.oneRTTPackets.IgnoreBelow(pn) } +func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } +} + func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - initialAlarm := h.initialPackets.GetAlarmTimeout() - handshakeAlarm := h.handshakePackets.GetAlarmTimeout() + var initialAlarm, handshakeAlarm time.Time + if h.initialPackets != nil { + initialAlarm = h.initialPackets.GetAlarmTimeout() + } + if h.handshakePackets != nil { + handshakeAlarm = h.handshakePackets.GetAlarmTimeout() + } oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout() return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) } @@ -86,9 +102,13 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) * var ack *wire.AckFrame switch encLevel { case protocol.EncryptionInitial: - ack = h.initialPackets.GetAckFrame() + if h.initialPackets != nil { + ack = h.initialPackets.GetAckFrame() + } case protocol.EncryptionHandshake: - ack = h.handshakePackets.GetAckFrame() + if h.handshakePackets != nil { + ack = h.handshakePackets.GetAckFrame() + } case protocol.Encryption1RTT: return h.oneRTTPackets.GetAckFrame() default: diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 80a26b7f9..7bcea8e34 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -47,4 +47,24 @@ var _ = Describe("Received Packet Handler", func() { Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) }) + + It("drops Initial packets", func() { + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + }) + + It("drops Handshake packets", func() { + sendTime := time.Now().Add(-time.Second) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) + Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil()) + Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil()) + }) }) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 4c395f27e..640b725a4 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -36,6 +36,18 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor return m.recorder } +// DropPackets mocks base method +func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropPackets", arg0) +} + +// DropPackets indicates an expected call of DropPackets +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) +} + // GetAckFrame mocks base method func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { m.ctrl.T.Helper()