diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 9fcead66..207b95a7 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -25,12 +25,14 @@ type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error + ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error SetHandshakeComplete() // The SendMode determines if and what kind of packets can be sent. SendMode() SendMode + AmplificationWindow() protocol.ByteCount // TimeUntilSend is the time when the next packet should be sent. // It is used for pacing packets. TimeUntilSend() time.Time @@ -56,6 +58,7 @@ type SentPacketHandler interface { type sentPacketTracker interface { GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + ReceivedPacket(protocol.EncryptionLevel) } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index 9ababef4..9c3e8b8c 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -47,3 +47,15 @@ func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) } + +// ReceivedPacket mocks base method +func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) +} diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 367d5355..b4e5f0b9 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -64,6 +64,7 @@ func (h *receivedPacketHandler) ReceivedPacket( rcvTime time.Time, shouldInstigateAck bool, ) error { + h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 985c6ce6..0018c68a 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -3,6 +3,8 @@ package ackhandler import ( "time" + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -29,6 +31,9 @@ var _ = Describe("Received Packet Handler", func() { It("generates ACKs for different packet number spaces", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -54,6 +59,8 @@ var _ = Describe("Received Packet Handler", func() { It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -64,6 +71,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) @@ -72,6 +80,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("allows reordered 0-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) @@ -80,6 +89,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops Initial packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) @@ -90,6 +100,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops Handshake packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) @@ -105,6 +116,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops old ACK ranges", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() sendTime := time.Now() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b3a3c40b..007c35ce 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -20,6 +20,8 @@ const ( timeThreshold = 9.0 / 8 // Maximum reordering in packets before packet threshold loss detection considers a packet lost. packetThreshold = 3 + // Before validating the client's address, the server won't send more than 3x bytes than it received. + amplificationFactor = 3 ) type packetNumberSpace struct { @@ -49,8 +51,16 @@ type sentPacketHandler struct { handshakePackets *packetNumberSpace appDataPackets *packetNumberSpace + // Do we know that the peer completed address validation yet? + // Always true for the server. peerCompletedAddressValidation bool - handshakeComplete bool + bytesReceived protocol.ByteCount + bytesSent protocol.ByteCount + // Have we validated the peer's address yet? + // Always true for the client. + peerAddressValidated bool + + handshakeComplete bool // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // example: we send an ACK for packets 90-100 with packet number 20 @@ -99,6 +109,7 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, + peerAddressValidated: pers == protocol.PerspectiveClient, initialPackets: newPacketNumberSpace(initialPacketNumber), handshakePackets: newPacketNumberSpace(0), appDataPackets: newPacketNumberSpace(0), @@ -168,6 +179,16 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { h.ptoMode = SendNone } +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { + h.bytesReceived += n +} + +func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake { + h.peerAddressValidated = true + } +} + func (h *sentPacketHandler) packetsInFlight() int { packetsInFlight := h.appDataPackets.history.Len() if h.handshakePackets != nil { @@ -180,6 +201,7 @@ func (h *sentPacketHandler) packetsInFlight() int { } func (h *sentPacketHandler) SentPacket(packet *Packet) { + h.bytesSent += packet.Length // For the client, drop the Initial packet number space when the first Handshake packet is sent. if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { h.dropPackets(protocol.EncryptionInitial) @@ -638,6 +660,10 @@ func (h *sentPacketHandler) SendMode() SendMode { numTrackedPackets += h.handshakePackets.history.Len() } + if h.AmplificationWindow() == 0 { + h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) + return SendNone + } // Don't send any packets if we're keeping track of the maximum number of packets. // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, // we will stop sending out new data when reaching MaxOutstandingSentPackets, @@ -683,6 +709,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } +func (h *sentPacketHandler) AmplificationWindow() protocol.ByteCount { + if h.peerAddressValidated { + return protocol.MaxByteCount + } + if h.bytesSent >= amplificationFactor*h.bytesReceived { + return 0 + } + return amplificationFactor*h.bytesReceived - h.bytesSent +} + func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { pnSpace := h.getPacketNumberSpace(encLevel) p := pnSpace.history.FirstOutstanding() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index e10bbb22..fb6df318 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -495,13 +495,51 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) }) - It("passes the bytes in flight to CanSend", func() { - handler.bytesInFlight = 42 - cong.EXPECT().CanSend(protocol.ByteCount(42)) + It("passes the bytes in flight to the congestion controller", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 42, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true) handler.SendMode() }) + It("returns SendNone if limited by the 3x limit", func() { + handler.ReceivedBytes(100) + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(300), gomock.Any(), protocol.ByteCount(300), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 300, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + cong.EXPECT().CanSend(protocol.ByteCount(300)).Return(true).AnyTimes() + Expect(handler.AmplificationWindow()).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendNone)) + }) + + It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { + handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(50), gomock.Any(), protocol.ByteCount(50), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 50, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + handler.ReceivedBytes(100) + Expect(handler.AmplificationWindow()).To(Equal(protocol.ByteCount(3*100 - 50))) + }) + It("allows sending of ACKs when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true) Expect(handler.SendMode()).To(Equal(SendAny)) cong.EXPECT().CanSend(gomock.Any()).Return(false) @@ -509,6 +547,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().TimeUntilSend(gomock.Any()).AnyTimes() cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -521,6 +560,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows PTOs, even when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 @@ -561,6 +601,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't set an alarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} @@ -569,6 +610,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("does nothing on OnAlarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendAny)) }) @@ -602,6 +644,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("reset the PTO count when receiving an ACK", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) @@ -615,6 +658,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("resets the PTO mode and PTO count when a packet number space is dropped", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 1, @@ -638,6 +682,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows two 1-RTT PTOs", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ @@ -657,6 +702,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("only counts ack-eliciting packets as probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -672,7 +718,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) - It("gets two probe packets if RTO expires", func() { + It("gets two probe packets if PTO expires", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -698,6 +745,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("gets two probe packets if PTO expires, for Handshake packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) @@ -714,6 +762,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't send 1-RTT probe packets before the handshake completes", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) updateRTT(time.Hour) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP @@ -726,6 +775,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) @@ -737,6 +787,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("handles ACKs for the original packet", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -993,6 +1044,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("cancels the PTO when dropping a packet number space", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) @@ -1028,12 +1080,15 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("resetting for retry", func() { + Context("for the client", func() { BeforeEach(func() { perspective = protocol.PerspectiveClient }) - It("queues outstanding packets for retransmission, cancels alarms and resets PTO count", func() { + It("considers the server's address validated right away", func() { + }) + + It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() { handler.SentPacket(initialPacket(&Packet{PacketNumber: 42})) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.bytesInFlight).ToNot(BeZero()) @@ -1047,7 +1102,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ptoCount).To(BeZero()) }) - It("queues outstanding frames for retransmission and cancels alarms", func() { + It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() { var lostInitial, lost0RTT bool handler.SentPacket(&Packet{ PacketNumber: 13, diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 07e57747..55757604 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -38,6 +38,20 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { return m.recorder } +// AmplificationWindow mocks base method +func (m *MockSentPacketHandler) AmplificationWindow() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AmplificationWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// AmplificationWindow indicates an expected call of AmplificationWindow +func (mr *MockSentPacketHandlerMockRecorder) AmplificationWindow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AmplificationWindow", reflect.TypeOf((*MockSentPacketHandler)(nil).AmplificationWindow)) +} + // DropPackets mocks base method func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { m.ctrl.T.Helper() @@ -149,6 +163,18 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) } +// ReceivedBytes mocks base method +func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedBytes", arg0) +} + +// ReceivedBytes indicates an expected call of ReceivedBytes +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) +} + // ResetForRetry mocks base method func (m *MockSentPacketHandler) ResetForRetry() error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 9627dcd1..862d8cbf 100644 --- a/session.go +++ b/session.go @@ -700,6 +700,7 @@ func (s *session) handlePacketImpl(rp *receivedPacket) bool { var processed bool data := rp.data p := rp + s.sentPacketHandler.ReceivedBytes(protocol.ByteCount(len(data))) for len(data) > 0 { if counter > 0 { p = p.Clone() @@ -1427,7 +1428,7 @@ func (s *session) sendPacket() (bool, error) { if !s.handshakeConfirmed { now := time.Now() - packet, err := s.packer.PackCoalescedPacket(protocol.MaxByteCount) + packet, err := s.packer.PackCoalescedPacket(s.sentPacketHandler.AmplificationWindow()) if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 52b88659..8213af93 100644 --- a/session_test.go +++ b/session_test.go @@ -1033,6 +1033,13 @@ var _ = Describe("Session", func() { It("sends packets", func() { sess.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph runSession() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) @@ -1069,6 +1076,13 @@ var _ = Describe("Session", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { sess.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() @@ -1366,10 +1380,12 @@ var _ = Describe("Session", func() { It("sends coalesced packets before the handshake is confirmed", func() { sess.handshakeConfirmed = false sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + const window protocol.ByteCount = 321 + sph.EXPECT().AmplificationWindow().Return(window).AnyTimes() sess.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket(window).Return(&coalescedPacket{ buffer: buffer, packets: []*packetContents{ { @@ -1394,7 +1410,7 @@ var _ = Describe("Session", func() { }, }, }, nil) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() + packer.EXPECT().PackCoalescedPacket(window).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1545,9 +1561,17 @@ var _ = Describe("Session", func() { }) It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) + sph.EXPECT().SetHandshakeComplete() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(10) + sess.sentPacketHandler = sph done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1559,7 +1583,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() + packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1659,6 +1683,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sess.config.MaxIdleTimeout = 30 * time.Second sess.config.KeepAlive = true + sess.receivedPacketHandler.ReceivedPacket(0, protocol.EncryptionHandshake, time.Now(), true) }) AfterEach(func() { @@ -2098,6 +2123,7 @@ var _ = Describe("Client Session", func() { It("handles Retry packets", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().ResetForRetry() cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) @@ -2333,6 +2359,7 @@ var _ = Describe("Client Session", func() { It("ignores Initial packets which use original source id, after accepting a Retry", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)