diff --git a/Changelog.md b/Changelog.md index 8f65a9d1..839aec33 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,8 @@ ## v0.6.1 (unreleased) +- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. + ## v0.6.0 (2017-12-12) - Add support for QUIC 39, drop support for QUIC 35 - 37 diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 7b68faad..914ee146 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -16,6 +16,7 @@ type SentPacketHandler interface { SendingAllowed() bool GetStopWaitingFrame(force bool) *wire.StopWaitingFrame + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber ShouldSendRetransmittablePacket() bool DequeuePacketForRetransmission() (packet *Packet) GetLeastUnacked() protocol.PacketNumber diff --git a/ackhandler/packet.go b/ackhandler/packet.go index 9c4ee30b..e4213a0b 100644 --- a/ackhandler/packet.go +++ b/ackhandler/packet.go @@ -15,7 +15,8 @@ type Packet struct { Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel - SendTime time.Time + largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK + sendTime time.Time } // GetFramesForRetransmission gets all the frames for retransmission diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 4fe66811..b30a78f7 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -40,6 +40,10 @@ type sentPacketHandler struct { largestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber + // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101 + lowestPacketNotConfirmedAcked protocol.PacketNumber packetHistory *PacketList stopWaitingManager stopWaitingManager @@ -114,11 +118,19 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { h.lastSentPacketNumber = packet.PacketNumber now := time.Now() + var largestAcked protocol.PacketNumber + if len(packet.Frames) > 0 { + if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { + largestAcked = ackFrame.LargestAcked + } + } + packet.Frames = stripNonRetransmittableFrames(packet.Frames) isRetransmittable := len(packet.Frames) != 0 if isRetransmittable { - packet.SendTime = now + packet.sendTime = now + packet.largestAcked = largestAcked h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) h.numNonRetransmittablePackets = 0 @@ -146,14 +158,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe // duplicate or out-of-order ACK // if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 { if withPacketNumber <= h.largestReceivedPacketWithAck { - utils.Debugf("ignoring ack because duplicate") return ErrDuplicateOrOutOfOrderAck } h.largestReceivedPacketWithAck = withPacketNumber // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) if ackFrame.LargestAcked < h.lowestUnacked() { - utils.Debugf("ignoring ack because repeated") return nil } h.largestAcked = ackFrame.LargestAcked @@ -178,6 +188,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe if encLevel < p.Value.EncryptionLevel { return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) } + // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 + // It is safe to ignore the corner case of packets that just acked packet 0, because + // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. + if p.Value.largestAcked != 0 { + h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1) + } h.onPacketAcked(p) h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } @@ -192,6 +208,10 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return nil } +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestPacketNotConfirmedAcked +} + func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { var ackedPackets []*PacketElement ackRangeIndex := 0 @@ -233,7 +253,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a for el := h.packetHistory.Front(); el != nil; el = el.Next() { packet := el.Value if packet.PacketNumber == largestAcked { - h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) + h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, time.Now()) return true } // Packets are sorted by number, so we can stop searching @@ -278,7 +298,7 @@ func (h *sentPacketHandler) detectLostPackets() { break } - timeSinceSent := now.Sub(packet.SendTime) + timeSinceSent := now.Sub(packet.sendTime) if timeSinceSent > delayUntilLost { lostPackets = append(lostPackets, el) } else if h.lossTime.IsZero() { diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 6648d552..f44b0285 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -143,7 +143,7 @@ var _ = Describe("SentPacketHandler", func() { packet := Packet{PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1} err := handler.SentPacket(&packet) Expect(err).ToNot(HaveOccurred()) - Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) + Expect(handler.packetHistory.Front().Value.sendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) It("does not store non-retransmittable packets", func() { @@ -553,9 +553,9 @@ var _ = Describe("SentPacketHandler", func() { It("computes the RTT", func() { now := time.Now() // First, fake the sent times of the first, second and last packet - getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) - getPacketElement(2).Value.SendTime = now.Add(-5 * time.Minute) - getPacketElement(6).Value.SendTime = now.Add(-1 * time.Minute) + getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) + getPacketElement(2).Value.sendTime = now.Add(-5 * time.Minute) + getPacketElement(6).Value.sendTime = now.Add(-1 * time.Minute) // Now, check that the proper times are used when calculating the deltas err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) @@ -570,12 +570,50 @@ var _ = Describe("SentPacketHandler", func() { It("uses the DelayTime in the ack frame", func() { now := time.Now() - getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) + getPacketElement(1).Value.sendTime = now.Add(-10 * time.Minute) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 1, DelayTime: 5 * time.Minute}, 1, protocol.EncryptionUnencrypted, time.Now()) Expect(err).NotTo(HaveOccurred()) Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second)) }) }) + + Context("determinining, which ACKs we have received an ACK for", func() { + BeforeEach(func() { + morePackets := []*Packet{ + &Packet{PacketNumber: 13, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 80, LargestAcked: 100}, &streamFrame}, Length: 1}, + &Packet{PacketNumber: 14, Frames: []wire.Frame{&wire.AckFrame{LowestAcked: 50, LargestAcked: 200}, &streamFrame}, Length: 1}, + &Packet{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1}, + } + for _, packet := range morePackets { + err := handler.SentPacket(packet) + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("determines which ACK we have received an ACK for", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 12}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + + It("doesn't do anything when the acked packet didn't contain an ACK", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 15, LowestAcked: 15}, 2, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101))) + }) + + It("doesn't decrease the value", func() { + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 14, LowestAcked: 14}, 1, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 13, LowestAcked: 13}, 2, protocol.EncryptionUnencrypted, time.Now()) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201))) + }) + }) }) Context("Retransmission handling", func() { @@ -606,7 +644,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("dequeues a packet for retransmission", func() { - getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour) + getPacketElement(1).Value.sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() Expect(getPacketElement(1)).To(BeNil()) Expect(handler.retransmissionQueue).To(HaveLen(1)) @@ -662,7 +700,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).NotTo(HaveOccurred()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour) + handler.packetHistory.Front().Value.sendTime = time.Now().Add(-time.Hour) handler.OnAlarm() Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0))) @@ -799,7 +837,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) - handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour) + handler.packetHistory.Front().Value.sendTime = time.Now().Add(-2 * time.Hour) handler.OnAlarm() Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) }) diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index 0f53b7a2..33df5c99 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -3,6 +3,8 @@ package mocks //go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS" //go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" //go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "./mockgen_internal.sh mocks sent_packet_handler.go github.com/lucas-clemente/quic-go/ackhandler SentPacketHandler" +//go:generate sh -c "./mockgen_internal.sh mocks received_packet_handler.go github.com/lucas-clemente/quic-go/ackhandler ReceivedPacketHandler" //go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" //go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" //go:generate sh -c "./mockgen_stream.sh mocks stream.go github.com/lucas-clemente/quic-go StreamI" diff --git a/internal/mocks/received_packet_handler.go b/internal/mocks/received_packet_handler.go new file mode 100644 index 00000000..a47424c7 --- /dev/null +++ b/internal/mocks/received_packet_handler.go @@ -0,0 +1,82 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/ackhandler (interfaces: ReceivedPacketHandler) + +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockReceivedPacketHandler is a mock of ReceivedPacketHandler interface +type MockReceivedPacketHandler struct { + ctrl *gomock.Controller + recorder *MockReceivedPacketHandlerMockRecorder +} + +// MockReceivedPacketHandlerMockRecorder is the mock recorder for MockReceivedPacketHandler +type MockReceivedPacketHandlerMockRecorder struct { + mock *MockReceivedPacketHandler +} + +// NewMockReceivedPacketHandler creates a new mock instance +func NewMockReceivedPacketHandler(ctrl *gomock.Controller) *MockReceivedPacketHandler { + mock := &MockReceivedPacketHandler{ctrl: ctrl} + mock.recorder = &MockReceivedPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecorder { + return _m.recorder +} + +// GetAckFrame mocks base method +func (_m *MockReceivedPacketHandler) GetAckFrame() *wire.AckFrame { + ret := _m.ctrl.Call(_m, "GetAckFrame") + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame +func (_mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame)) +} + +// GetAlarmTimeout mocks base method +func (_m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { + ret := _m.ctrl.Call(_m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout +func (_mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) +} + +// IgnoreBelow mocks base method +func (_m *MockReceivedPacketHandler) IgnoreBelow(_param0 protocol.PacketNumber) { + _m.ctrl.Call(_m, "IgnoreBelow", _param0) +} + +// IgnoreBelow indicates an expected call of IgnoreBelow +func (_mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IgnoreBelow", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IgnoreBelow), arg0) +} + +// ReceivedPacket mocks base method +func (_m *MockReceivedPacketHandler) ReceivedPacket(_param0 protocol.PacketNumber, _param1 bool) error { + ret := _m.ctrl.Call(_m, "ReceivedPacket", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedPacket indicates an expected call of ReceivedPacket +func (_mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1) +} diff --git a/internal/mocks/sent_packet_handler.go b/internal/mocks/sent_packet_handler.go new file mode 100644 index 00000000..13eb2da2 --- /dev/null +++ b/internal/mocks/sent_packet_handler.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/ackhandler (interfaces: SentPacketHandler) + +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/ackhandler" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockSentPacketHandler is a mock of SentPacketHandler interface +type MockSentPacketHandler struct { + ctrl *gomock.Controller + recorder *MockSentPacketHandlerMockRecorder +} + +// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler +type MockSentPacketHandlerMockRecorder struct { + mock *MockSentPacketHandler +} + +// NewMockSentPacketHandler creates a new mock instance +func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler { + mock := &MockSentPacketHandler{ctrl: ctrl} + mock.recorder = &MockSentPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { + return _m.recorder +} + +// DequeuePacketForRetransmission mocks base method +func (_m *MockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { + ret := _m.ctrl.Call(_m, "DequeuePacketForRetransmission") + ret0, _ := ret[0].(*ackhandler.Packet) + return ret0 +} + +// DequeuePacketForRetransmission indicates an expected call of DequeuePacketForRetransmission +func (_mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission)) +} + +// GetAlarmTimeout mocks base method +func (_m *MockSentPacketHandler) GetAlarmTimeout() time.Time { + ret := _m.ctrl.Call(_m, "GetAlarmTimeout") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// GetAlarmTimeout indicates an expected call of GetAlarmTimeout +func (_mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout)) +} + +// GetLeastUnacked mocks base method +func (_m *MockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { + ret := _m.ctrl.Call(_m, "GetLeastUnacked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLeastUnacked indicates an expected call of GetLeastUnacked +func (_mr *MockSentPacketHandlerMockRecorder) GetLeastUnacked() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetLeastUnacked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLeastUnacked)) +} + +// GetLowestPacketNotConfirmedAcked mocks base method +func (_m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + ret := _m.ctrl.Call(_m, "GetLowestPacketNotConfirmedAcked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked +func (_mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked)) +} + +// GetStopWaitingFrame mocks base method +func (_m *MockSentPacketHandler) GetStopWaitingFrame(_param0 bool) *wire.StopWaitingFrame { + ret := _m.ctrl.Call(_m, "GetStopWaitingFrame", _param0) + ret0, _ := ret[0].(*wire.StopWaitingFrame) + return ret0 +} + +// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame +func (_mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0) +} + +// OnAlarm mocks base method +func (_m *MockSentPacketHandler) OnAlarm() { + _m.ctrl.Call(_m, "OnAlarm") +} + +// OnAlarm indicates an expected call of OnAlarm +func (_mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm)) +} + +// ReceivedAck mocks base method +func (_m *MockSentPacketHandler) ReceivedAck(_param0 *wire.AckFrame, _param1 protocol.PacketNumber, _param2 protocol.EncryptionLevel, _param3 time.Time) error { + ret := _m.ctrl.Call(_m, "ReceivedAck", _param0, _param1, _param2, _param3) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReceivedAck indicates an expected call of ReceivedAck +func (_mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2, arg3) +} + +// SendingAllowed mocks base method +func (_m *MockSentPacketHandler) SendingAllowed() bool { + ret := _m.ctrl.Call(_m, "SendingAllowed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// SendingAllowed indicates an expected call of SendingAllowed +func (_mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendingAllowed", reflect.TypeOf((*MockSentPacketHandler)(nil).SendingAllowed)) +} + +// SentPacket mocks base method +func (_m *MockSentPacketHandler) SentPacket(_param0 *ackhandler.Packet) error { + ret := _m.ctrl.Call(_m, "SentPacket", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SentPacket indicates an expected call of SentPacket +func (_mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) +} + +// SetHandshakeComplete mocks base method +func (_m *MockSentPacketHandler) SetHandshakeComplete() { + _m.ctrl.Call(_m, "SetHandshakeComplete") +} + +// SetHandshakeComplete indicates an expected call of SetHandshakeComplete +func (_mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) +} + +// ShouldSendRetransmittablePacket mocks base method +func (_m *MockSentPacketHandler) ShouldSendRetransmittablePacket() bool { + ret := _m.ctrl.Call(_m, "ShouldSendRetransmittablePacket") + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShouldSendRetransmittablePacket indicates an expected call of ShouldSendRetransmittablePacket +func (_mr *MockSentPacketHandlerMockRecorder) ShouldSendRetransmittablePacket() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ShouldSendRetransmittablePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendRetransmittablePacket)) +} diff --git a/packet_packer.go b/packet_packer.go index 83ea334f..ff4c0919 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -212,15 +212,15 @@ func (p *packetPacker) composeNextPacket( var payloadFrames []wire.Frame // STOP_WAITING and ACK will always fit - if p.stopWaiting != nil { - payloadFrames = append(payloadFrames, p.stopWaiting) - payloadLength += p.stopWaiting.MinLength(p.version) - } - if p.ackFrame != nil { + if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them payloadFrames = append(payloadFrames, p.ackFrame) l := p.ackFrame.MinLength(p.version) payloadLength += l } + if p.stopWaiting != nil { + payloadFrames = append(payloadFrames, p.stopWaiting) + payloadLength += p.stopWaiting.MinLength(p.version) + } p.controlFrameMutex.Lock() for len(p.controlFrames) > 0 { diff --git a/session.go b/session.go index 316eafba..4b732c1f 100644 --- a/session.go +++ b/session.go @@ -523,8 +523,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *wire.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") - case *wire.StopWaitingFrame: - s.receivedPacketHandler.IgnoreBelow(frame.LeastUnacked) + case *wire.StopWaitingFrame: // ignore STOP_WAITINGs case *wire.RstStreamFrame: err = s.handleRstStreamFrame(frame) case *wire.MaxDataFrame: @@ -616,7 +615,11 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { } func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime) + if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { + return err + } + s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) + return nil } func (s *session) closeLocal(e error) { diff --git a/session_test.go b/session_test.go index cf1768ea..9e6ecece 100644 --- a/session_test.go +++ b/session_test.go @@ -70,73 +70,6 @@ func (m *mockUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte }, nil } -type mockSentPacketHandler struct { - retransmissionQueue []*ackhandler.Packet - sentPackets []*ackhandler.Packet - congestionLimited bool - requestedStopWaiting bool - shouldSendRetransmittablePacket bool -} - -func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { - h.sentPackets = append(h.sentPackets, packet) - return nil -} - -func (h *mockSentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error { - return nil -} -func (h *mockSentPacketHandler) SetHandshakeComplete() {} -func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 } -func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } -func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } -func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } -func (h *mockSentPacketHandler) ShouldSendRetransmittablePacket() bool { - b := h.shouldSendRetransmittablePacket - h.shouldSendRetransmittablePacket = false - return b -} - -func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { - h.requestedStopWaiting = true - return &wire.StopWaitingFrame{LeastUnacked: 0x1337} -} - -func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { - if len(h.retransmissionQueue) > 0 { - packet := h.retransmissionQueue[0] - h.retransmissionQueue = h.retransmissionQueue[1:] - return packet - } - return nil -} - -func newMockSentPacketHandler() ackhandler.SentPacketHandler { - return &mockSentPacketHandler{} -} - -var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{} - -type mockReceivedPacketHandler struct { - nextAckFrame *wire.AckFrame - ackAlarm time.Time -} - -func (m *mockReceivedPacketHandler) GetAckFrame() *wire.AckFrame { - f := m.nextAckFrame - m.nextAckFrame = nil - return f -} -func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { - panic("not implemented") -} -func (m *mockReceivedPacketHandler) IgnoreBelow(protocol.PacketNumber) { - panic("not implemented") -} -func (m *mockReceivedPacketHandler) GetAlarmTimeout() time.Time { return m.ackAlarm } - -var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} - func areSessionsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) @@ -341,6 +274,31 @@ var _ = Describe("Session", func() { }) }) + Context("handling ACK frames", func() { + It("informs the SentPacketHandler about ACKs", func() { + f := &wire.AckFrame{LargestAcked: 3, LowestAcked: 2} + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().ReceivedAck(f, protocol.PacketNumber(42), protocol.EncryptionSecure, gomock.Any()) + sph.EXPECT().GetLowestPacketNotConfirmedAcked() + sess.sentPacketHandler = sph + sess.lastRcvdPacketNumber = 42 + err := sess.handleAckFrame(f, protocol.EncryptionSecure) + Expect(err).ToNot(HaveOccurred()) + }) + + It("tells the ReceivedPacketHandler to ignore low ranges", func() { + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + sph.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(0x42)) + sess.sentPacketHandler = sph + rph := mocks.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().IgnoreBelow(protocol.PacketNumber(0x42)) + sess.receivedPacketHandler = rph + err := sess.handleAckFrame(&wire.AckFrame{LargestAcked: 3, LowestAcked: 2}, protocol.EncryptionUnencrypted) + Expect(err).ToNot(HaveOccurred()) + }) + }) + Context("handling RST_STREAM frames", func() { It("closes the streams for writing", func() { f := &wire.RstStreamFrame{ @@ -803,7 +761,12 @@ var _ = Describe("Session", func() { }) It("sends ACK frames when congestion limited", func() { - sess.sentPacketHandler = &mockSentPacketHandler{congestionLimited: true} + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(false) + sph.EXPECT().GetStopWaitingFrame(false) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph sess.packer.packetNumberGenerator.next = 0x1338 packetNumber := protocol.PacketNumber(0x035e) sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) @@ -814,12 +777,18 @@ var _ = Describe("Session", func() { }) It("sends a retransmittable packet when required by the SentPacketHandler", func() { - sess.sentPacketHandler = &mockSentPacketHandler{shouldSendRetransmittablePacket: true} sess.packer.QueueControlFrame(&wire.AckFrame{LargestAcked: 1000}) + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(true) + sph.EXPECT().SendingAllowed().Return(false) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket().Return(true) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph err := sess.sendPacket() Expect(err).ToNot(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - Expect(sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets[0].Frames).To(ContainElement(&wire.PingFrame{})) }) It("sends public reset", func() { @@ -830,36 +799,48 @@ var _ = Describe("Session", func() { }) It("informs the SentPacketHandler about sent packets", func() { - sess.sentPacketHandler = newMockSentPacketHandler() - sess.packer.packetNumberGenerator.next = 0x1337 + 9 - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), } + var sentPacket *ackhandler.Packet + sph := mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetStopWaitingFrame(gomock.Any()) + sph.EXPECT().SendingAllowed().Return(true) + sph.EXPECT().SendingAllowed().Return(false) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().ShouldSendRetransmittablePacket() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + sentPacket = p + }) + sess.sentPacketHandler = sph + sess.packer.packetNumberGenerator.next = 0x1337 + 9 + sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} + sess.streamFramer.AddFrameForRetransmission(f) _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - sentPackets := sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].Frames).To(ContainElement(f)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - Expect(mconn.written).To(HaveLen(1)) - Expect(sentPackets[0].Length).To(BeEquivalentTo(len(<-mconn.written))) + Expect(sentPacket.PacketNumber).To(Equal(protocol.PacketNumber(0x1337 + 9))) + Expect(sentPacket.Frames).To(ContainElement(f)) + Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) + Expect(sentPacket.Length).To(BeEquivalentTo(len(<-mconn.written))) }) }) Context("retransmissions", func() { - var sph *mockSentPacketHandler + var sph *mocks.MockSentPacketHandler BeforeEach(func() { - // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet + // a STOP_WAITING frame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet sess.packer.packetNumberGenerator.next = 0x1337 + 10 sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sph = newMockSentPacketHandler().(*mockSentPacketHandler) + sph = mocks.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().SendingAllowed().Return(true) + sph.EXPECT().ShouldSendRetransmittablePacket() sess.sentPacketHandler = sph sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) @@ -867,45 +848,36 @@ var _ = Describe("Session", func() { Context("for handshake packets", func() { It("retransmits an unencrypted packet", func() { sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }} + var sentPacket *ackhandler.Packet + sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{LeastUnacked: 0x1337}) + sph.EXPECT().DequeuePacketForRetransmission().Return( + &ackhandler.Packet{ + Frames: []wire.Frame{sf}, + EncryptionLevel: protocol.EncryptionUnencrypted, + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + sentPacket = p + }) err := sess.sendPacket() Expect(err).ToNot(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(sentPackets[0].Frames).To(HaveLen(2)) - Expect(sentPackets[0].Frames[1]).To(Equal(sf)) - swf := sentPackets[0].Frames[0].(*wire.StopWaitingFrame) + Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(sentPacket.Frames).To(HaveLen(2)) + Expect(sentPacket.Frames[1]).To(Equal(sf)) + swf := sentPacket.Frames[0].(*wire.StopWaitingFrame) Expect(swf.LeastUnacked).To(Equal(protocol.PacketNumber(0x1337))) }) - It("retransmit a packet encrypted with the initial encryption", func() { - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionSecure, - }} - err := sess.sendPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(HaveLen(1)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(1)) - Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(sentPackets[0].Frames).To(HaveLen(2)) - Expect(sentPackets[0].Frames).To(ContainElement(sf)) - }) - It("doesn't retransmit handshake packets when the handshake is complete", func() { sess.handshakeComplete = true sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.retransmissionQueue = []*ackhandler.Packet{{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionSecure, - }} + sph.EXPECT().DequeuePacketForRetransmission().Return( + &ackhandler.Packet{ + Frames: []wire.Frame{sf}, + EncryptionLevel: protocol.EncryptionSecure, + }) + sph.EXPECT().DequeuePacketForRetransmission() err := sess.sendPacket() Expect(err).ToNot(HaveOccurred()) Expect(mconn.written).To(BeEmpty()) @@ -918,17 +890,19 @@ var _ = Describe("Session", func() { StreamID: 0x5, Data: []byte("foobar1234567"), } - p := ackhandler.Packet{ - PacketNumber: 0x1337, - Frames: []wire.Frame{&f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - } - sph.retransmissionQueue = []*ackhandler.Packet{&p} - + sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) + sph.EXPECT().DequeuePacketForRetransmission().Return( + &ackhandler.Packet{ + PacketNumber: 0x1337, + Frames: []wire.Frame{&f}, + EncryptionLevel: protocol.EncryptionForwardSecure, + }) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().SendingAllowed() err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) - Expect(sph.requestedStopWaiting).To(BeTrue()) Expect(mconn.written).To(Receive(ContainSubstring("foobar1234567"))) }) @@ -941,17 +915,22 @@ var _ = Describe("Session", func() { StreamID: 0x7, Data: []byte("loremipsum"), } - p1 := ackhandler.Packet{ + p1 := &ackhandler.Packet{ PacketNumber: 0x1337, Frames: []wire.Frame{&f1}, EncryptionLevel: protocol.EncryptionForwardSecure, } - p2 := ackhandler.Packet{ + p2 := &ackhandler.Packet{ PacketNumber: 0x1338, Frames: []wire.Frame{&f2}, EncryptionLevel: protocol.EncryptionForwardSecure, } - sph.retransmissionQueue = []*ackhandler.Packet{&p1, &p2} + sph.EXPECT().DequeuePacketForRetransmission().Return(p1) + sph.EXPECT().DequeuePacketForRetransmission().Return(p2) + sph.EXPECT().DequeuePacketForRetransmission() + sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().SendingAllowed() err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) @@ -960,24 +939,6 @@ var _ = Describe("Session", func() { Expect(packet).To(ContainSubstring("foobar")) Expect(packet).To(ContainSubstring("loremipsum")) }) - - It("always attaches a StopWaiting to a packet that contains a retransmission", func() { - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: bytes.Repeat([]byte{'f'}, int(1.5*float32(protocol.MaxPacketSize))), - } - sess.streamFramer.AddFrameForRetransmission(f) - - err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(mconn.written).To(HaveLen(2)) - sentPackets := sph.sentPackets - Expect(sentPackets).To(HaveLen(2)) - _, ok := sentPackets[0].Frames[0].(*wire.StopWaitingFrame) - Expect(ok).To(BeTrue()) - _, ok = sentPackets[1].Frames[0].(*wire.StopWaitingFrame) - Expect(ok).To(BeTrue()) - }) }) }) @@ -1032,10 +993,14 @@ var _ = Describe("Session", func() { }) It("sets the timer to the ack timer", func() { - rph := &mockReceivedPacketHandler{ackAlarm: time.Now().Add(10 * time.Millisecond)} - rph.nextAckFrame = &wire.AckFrame{LargestAcked: 0x1337} + rph := mocks.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().GetAckFrame().Return(&wire.AckFrame{LargestAcked: 0x1337}) + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)).MinTimes(1) sess.receivedPacketHandler = rph - go sess.run() + go func() { + defer GinkgoRecover() + sess.run() + }() defer sess.Close(nil) time.Sleep(10 * time.Millisecond) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero())