diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index d0bce26e..4ea8a3c0 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -41,9 +41,9 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error + ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error IgnoreBelow(protocol.PacketNumber) GetAlarmTimeout() time.Time - GetAckFrame() *wire.AckFrame + GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame } diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go new file mode 100644 index 00000000..1df64ea8 --- /dev/null +++ b/internal/ackhandler/received_packet_handler.go @@ -0,0 +1,98 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +const ( + // maximum delay that can be applied to an ACK for a retransmittable packet + ackSendDelay = 25 * time.Millisecond + // initial maximum number of retransmittable packets received before sending an ack. + initialRetransmittablePacketsBeforeAck = 2 + // number of retransmittable that an ACK is sent for + retransmittablePacketsBeforeAck = 10 + // 1/5 RTT delay when doing ack decimation + ackDecimationDelay = 1.0 / 4 + // 1/8 RTT delay when doing ack decimation + shortAckDecimationDelay = 1.0 / 8 + // Minimum number of packets received before ack decimation is enabled. + // This intends to avoid the beginning of slow start, when CWNDs may be + // rapidly increasing. + minReceivedBeforeAckDecimation = 100 + // Maximum number of packets to ack immediately after a missing packet for + // fast retransmission to kick in at the sender. This limit is created to + // reduce the number of acks sent that have no benefit for fast retransmission. + // Set to the number of nacks needed for fast retransmit plus one for protection + // against an ack loss + maxPacketsAfterNewMissing = 4 +) + +type receivedPacketHandler struct { + initialPackets *receivedPacketTracker + handshakePackets *receivedPacketTracker + oneRTTPackets *receivedPacketTracker +} + +var _ ReceivedPacketHandler = &receivedPacketHandler{} + +// NewReceivedPacketHandler creates a new receivedPacketHandler +func NewReceivedPacketHandler( + rttStats *congestion.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) ReceivedPacketHandler { + return &receivedPacketHandler{ + initialPackets: newReceivedPacketTracker(rttStats, logger, version), + handshakePackets: newReceivedPacketTracker(rttStats, logger, version), + oneRTTPackets: newReceivedPacketTracker(rttStats, logger, version), + } +} + +func (h *receivedPacketHandler) ReceivedPacket( + pn protocol.PacketNumber, + encLevel protocol.EncryptionLevel, + rcvTime time.Time, + shouldInstigateAck bool, +) error { + switch encLevel { + case protocol.EncryptionInitial: + return h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + case protocol.EncryptionHandshake: + return h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + case protocol.Encryption1RTT: + return h.oneRTTPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + default: + return fmt.Errorf("received packet with unknown encryption level: %s", encLevel) + } +} + +// only to be used with 1-RTT packets +func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) { + h.oneRTTPackets.IgnoreBelow(pn) +} + +func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { + initialAlarm := h.initialPackets.GetAlarmTimeout() + handshakeAlarm := h.handshakePackets.GetAlarmTimeout() + oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout() + return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) +} + +func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *wire.AckFrame { + switch encLevel { + case protocol.EncryptionInitial: + return h.initialPackets.GetAckFrame() + case protocol.EncryptionHandshake: + return h.handshakePackets.GetAckFrame() + case protocol.Encryption1RTT: + return h.oneRTTPackets.GetAckFrame() + default: + return nil + } +} diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go new file mode 100644 index 00000000..5938a43c --- /dev/null +++ b/internal/ackhandler/received_packet_handler_test.go @@ -0,0 +1,47 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Received Packet Handler", func() { + var handler ReceivedPacketHandler + + BeforeEach(func() { + handler = NewReceivedPacketHandler( + &congestion.RTTStats{}, + utils.DefaultLogger, + protocol.VersionWhatever, + ) + }) + + It("generates ACKs for different packet number spaces", func() { + now := time.Now() + Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, now, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, now, true)).To(Succeed()) + Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, now, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.EncryptionInitial, now, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.EncryptionHandshake, now, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, now, true)).To(Succeed()) + initialAck := handler.GetAckFrame(protocol.EncryptionInitial) + Expect(initialAck).ToNot(BeNil()) + Expect(initialAck.AckRanges).To(HaveLen(1)) + Expect(initialAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) + handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake) + Expect(handshakeAck).ToNot(BeNil()) + Expect(handshakeAck.AckRanges).To(HaveLen(1)) + Expect(handshakeAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) + oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT) + Expect(oneRTTAck).ToNot(BeNil()) + Expect(oneRTTAck.AckRanges).To(HaveLen(1)) + Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) + }) +}) diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 6beb5477..b7840490 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -30,35 +30,11 @@ type receivedPacketTracker struct { version protocol.VersionNumber } -const ( - // maximum delay that can be applied to an ACK for a retransmittable packet - ackSendDelay = 25 * time.Millisecond - // initial maximum number of retransmittable packets received before sending an ack. - initialRetransmittablePacketsBeforeAck = 2 - // number of retransmittable that an ACK is sent for - retransmittablePacketsBeforeAck = 10 - // 1/5 RTT delay when doing ack decimation - ackDecimationDelay = 1.0 / 4 - // 1/8 RTT delay when doing ack decimation - shortAckDecimationDelay = 1.0 / 8 - // Minimum number of packets received before ack decimation is enabled. - // This intends to avoid the beginning of slow start, when CWNDs may be - // rapidly increasing. - minReceivedBeforeAckDecimation = 100 - // Maximum number of packets to ack immediately after a missing packet for - // fast retransmission to kick in at the sender. This limit is created to - // reduce the number of acks sent that have no benefit for fast retransmission. - // Set to the number of nacks needed for fast retransmit plus one for protection - // against an ack loss - maxPacketsAfterNewMissing = 4 -) - -// NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler( +func newReceivedPacketTracker( rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber, -) ReceivedPacketHandler { +) *receivedPacketTracker { return &receivedPacketTracker{ packetHistory: newReceivedPacketHistory(), ackSendDelay: ackSendDelay, diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index bec1738f..225a3662 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -20,7 +20,7 @@ var _ = Describe("Received Packet Tracker", func() { BeforeEach(func() { rttStats = &congestion.RTTStats{} - tracker = NewReceivedPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*receivedPacketTracker) + tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger, protocol.VersionWhatever) }) Context("accepting packets", func() { diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index d3e864f6..9e25010c 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -37,15 +37,15 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor } // GetAckFrame mocks base method -func (m *MockReceivedPacketHandler) GetAckFrame() *wire.AckFrame { - ret := m.ctrl.Call(m, "GetAckFrame") +func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { + ret := m.ctrl.Call(m, "GetAckFrame", arg0) ret0, _ := ret[0].(*wire.AckFrame) return ret0 } // GetAckFrame indicates an expected call of GetAckFrame -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame)) +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0) } // GetAlarmTimeout mocks base method @@ -71,13 +71,13 @@ func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) * } // ReceivedPacket mocks base method -func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 time.Time, arg2 bool) error { - ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2) +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error { + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // ReceivedPacket indicates an expected call of ReceivedPacket -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2) +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3) } diff --git a/internal/utils/minmax.go b/internal/utils/minmax.go index 4394ab04..84cbec7b 100644 --- a/internal/utils/minmax.go +++ b/internal/utils/minmax.go @@ -122,6 +122,18 @@ func MinTime(a, b time.Time) time.Time { return a } +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + // MaxTime returns the later time func MaxTime(a, b time.Time) time.Time { if a.After(b) { diff --git a/internal/utils/minmax_test.go b/internal/utils/minmax_test.go index 95816372..16b1f57a 100644 --- a/internal/utils/minmax_test.go +++ b/internal/utils/minmax_test.go @@ -95,6 +95,16 @@ var _ = Describe("Min / Max", func() { Expect(MinTime(a, b)).To(Equal(a)) Expect(MinTime(b, a)).To(Equal(a)) }) + + It("returns the minium non-zero time", func() { + a := time.Time{} + b := time.Now() + Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) + Expect(MinNonZeroTime(a, b)).To(Equal(b)) + Expect(MinNonZeroTime(b, a)).To(Equal(b)) + Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) + Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) + }) }) It("returns the abs time", func() { diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index d1c44373..f1362124 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -8,6 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -35,13 +36,13 @@ func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { } // GetAckFrame mocks base method -func (m *MockAckFrameSource) GetAckFrame() *wire.AckFrame { - ret := m.ctrl.Call(m, "GetAckFrame") +func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame { + ret := m.ctrl.Call(m, "GetAckFrame", arg0) ret0, _ := ret[0].(*wire.AckFrame) return ret0 } // GetAckFrame indicates an expected call of GetAckFrame -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame)) +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0) } diff --git a/packet_packer.go b/packet_packer.go index aa55262b..93a3bc3b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -90,7 +90,7 @@ type frameSource interface { } type ackFrameSource interface { - GetAckFrame() *wire.AckFrame + GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame } type packetPacker struct { @@ -155,7 +155,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { - ack := p.acks.GetAckFrame() + ack := p.acks.GetAckFrame(protocol.Encryption1RTT) if ack == nil { return nil, nil } @@ -285,30 +285,41 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { var s cryptoStream var encLevel protocol.EncryptionLevel - if p.initialStream.HasData() { + + hasData := p.initialStream.HasData() + ack := p.acks.GetAckFrame(protocol.EncryptionInitial) + if hasData || ack != nil { s = p.initialStream encLevel = protocol.EncryptionInitial - } else if p.handshakeStream.HasData() { - s = p.handshakeStream - encLevel = protocol.EncryptionHandshake + } else { + hasData = p.handshakeStream.HasData() + ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) + if hasData || ack != nil { + s = p.handshakeStream + encLevel = protocol.EncryptionHandshake + } } if s == nil { return nil, nil } - hdr := p.getHeader(encLevel) - hdrLen := hdr.GetLength(p.version) sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel) if err != nil { + // The sealer return nil, err } + + hdr := p.getHeader(encLevel) + hdrLen := hdr.GetLength(p.version) var length protocol.ByteCount frames := make([]wire.Frame, 0, 2) - if ack := p.acks.GetAckFrame(); ack != nil { + if ack != nil { frames = append(frames, ack) length += ack.Length(p.version) } - cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) - frames = append(frames, cf) + if hasData { + cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) + frames = append(frames, cf) + } return p.writeAndSealPacket(hdr, frames, sealer) } @@ -317,7 +328,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir var frames []wire.Frame // ACKs need to go first, so that the sentPacketHandler will recognize them - if ack := p.acks.GetAckFrame(); ack != nil { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil { frames = append(frames, ack) length += ack.Length(p.version) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 1971700e..ede87fd4 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -9,7 +9,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" + mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -174,7 +174,9 @@ var _ = Describe("Packet packer", func() { }), ) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(f) @@ -206,14 +208,16 @@ var _ = Describe("Packet packer", func() { Context("packing normal packets", func() { BeforeEach(func() { initialStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).AnyTimes() handshakeStream.EXPECT().HasData().AnyTimes() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).AnyTimes() }) It("returns nil when no packet is queued", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) // don't expect any calls to PopPacketNumber sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) framer.EXPECT().AppendControlFrames(nil, gomock.Any()) framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) p, err := packer.PackPacket() @@ -225,7 +229,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() f := &wire.StreamFrame{ StreamID: 5, @@ -245,7 +249,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(&wire.StreamFrame{ StreamID: 5, @@ -260,7 +264,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - ackFramer.EXPECT().GetAckFrame().Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) expectAppendControlFrames() expectAppendStreamFrames() @@ -289,7 +293,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) frames := []wire.Frame{&wire.ResetStreamFrame{}, &wire.MaxDataFrame{}} expectAppendControlFrames(frames...) expectAppendStreamFrames() @@ -303,7 +307,7 @@ var _ = Describe("Packet packer", func() { It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) var maxSize protocol.ByteCount gomock.InOrder( framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { @@ -321,7 +325,7 @@ var _ = Describe("Packet packer", func() { Context("packing ACK packets", func() { It("doesn't pack a packet if there's no ACK to send", func() { - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) p, err := packer.MaybePackAckPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -332,7 +336,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame().Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) p, err := packer.MaybePackAckPacket() Expect(err).NotTo(HaveOccurred()) Expect(p.frames).To(Equal([]wire.Frame{ack})) @@ -345,7 +349,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() @@ -360,7 +364,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() @@ -371,7 +375,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err = packer.PackPacket() @@ -387,7 +391,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) expectAppendControlFrames() expectAppendStreamFrames() - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -397,7 +401,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(2)) @@ -409,7 +413,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendStreamFrames() expectAppendControlFrames(&wire.MaxDataFrame{}) p, err := packer.PackPacket() @@ -423,7 +427,7 @@ var _ = Describe("Packet packer", func() { It("does not split a STREAM frame with maximum size", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) expectAppendControlFrames() sf := &wire.StreamFrame{ @@ -462,7 +466,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) expectAppendControlFrames() expectAppendStreamFrames(f1, f2, f3) p, err := packer.PackPacket() @@ -651,7 +655,7 @@ var _ = Describe("Packet packer", func() { It("sets the maximum packet size", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen @@ -676,7 +680,7 @@ var _ = Describe("Packet packer", func() { It("doesn't increase the max packet size", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Times(2) var initialMaxPacketSize protocol.ByteCount framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { initialMaxPacketSize = maxLen @@ -708,7 +712,7 @@ var _ = Describe("Packet packer", func() { Offset: 0x1337, Data: []byte("foobar"), } - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) @@ -722,7 +726,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionHandshake).Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData().Return(true) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { @@ -740,12 +745,38 @@ var _ = Describe("Packet packer", func() { checkLength(p.raw) }) + It("sends a Initial packet containing only an ACK", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) + initialStream.EXPECT().HasData() + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) + pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) + }) + + It("sends a Handshake packet containing only an ACK", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData() + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionHandshake).Return(sealer, nil) + pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) + }) + It("pads Initial packets to the required minimum packet size", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient @@ -767,7 +798,9 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen1) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) initialStream.EXPECT().HasData() handshakeStream.EXPECT().HasData() framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()) @@ -798,7 +831,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ Data: []byte("foobar"), @@ -815,7 +848,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame().Return(ack) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) initialStream.EXPECT().HasData().Return(true) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS diff --git a/session.go b/session.go index 69d2cf7a..899132ac 100644 --- a/session.go +++ b/session.go @@ -566,7 +566,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time } } - if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, rcvTime, isRetransmittable); err != nil { + if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isRetransmittable); err != nil { return err } return nil diff --git a/session_test.go b/session_test.go index f0bb0796..6617609f 100644 --- a/session_test.go +++ b/session_test.go @@ -495,12 +495,13 @@ var _ = Describe("Session", func() { } rcvTime := time.Now().Add(-10 * time.Second) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - packetNumber: 0x1337, - hdr: hdr, - data: []byte{0}, // one PADDING frame + packetNumber: 0x1337, + encryptionLevel: protocol.EncryptionInitial, + hdr: hdr, + data: []byte{0}, // one PADDING frame }, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false) sess.receivedPacketHandler = rph Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ rcvTime: rcvTime, @@ -518,12 +519,13 @@ var _ = Describe("Session", func() { buf := &bytes.Buffer{} Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - packetNumber: 0x1337, - hdr: hdr, - data: buf.Bytes(), + packetNumber: 0x1337, + encryptionLevel: protocol.EncryptionHandshake, + hdr: hdr, + data: buf.Bytes(), }, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, true) + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionHandshake, rcvTime, true) sess.receivedPacketHandler = rph Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ rcvTime: rcvTime, @@ -583,8 +585,9 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, } unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - hdr: hdr, - data: []byte{0}, // one PADDING frame + encryptionLevel: protocol.Encryption1RTT, + hdr: hdr, + data: []byte{0}, // one PADDING frame }, nil).Times(2) Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) @@ -610,8 +613,9 @@ var _ = Describe("Session", func() { // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - hdr: &wire.ExtendedHeader{Header: *hdr}, - data: []byte{0}, // one PADDING frame + encryptionLevel: protocol.Encryption1RTT, + hdr: &wire.ExtendedHeader{Header: *hdr}, + data: []byte{0}, // one PADDING frame }, nil) Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: hdr, @@ -632,8 +636,9 @@ var _ = Describe("Session", func() { Context("updating the remote address", func() { It("doesn't support connection migration", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - hdr: &wire.ExtendedHeader{}, - data: []byte{0}, // one PADDING frame + encryptionLevel: protocol.Encryption1RTT, + hdr: &wire.ExtendedHeader{}, + data: []byte{0}, // one PADDING frame }, nil) origAddr := sess.conn.(*mockConnection).remoteAddr remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} @@ -662,8 +667,7 @@ var _ = Describe("Session", func() { It("sends packets", func() { packer.EXPECT().PackPacket().Return(getPacket(1), nil) - err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) - Expect(err).ToNot(HaveOccurred()) + Expect(sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true)).To(Succeed()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) @@ -671,8 +675,7 @@ var _ = Describe("Session", func() { It("doesn't send packets if there's nothing to send", func() { packer.EXPECT().PackPacket().Return(getPacket(2), nil) - err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) - Expect(err).ToNot(HaveOccurred()) + Expect(sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true)).To(Succeed()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) @@ -1408,8 +1411,9 @@ var _ = Describe("Client Session", func() { unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { return &unpackedPacket{ - hdr: &wire.ExtendedHeader{Header: *hdr}, - data: []byte{0}, // one PADDING frame + encryptionLevel: protocol.Encryption1RTT, + hdr: &wire.ExtendedHeader{Header: *hdr}, + data: []byte{0}, // one PADDING frame }, nil }) sess.unpacker = unpacker