diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 226bfcbb..9079fb0b 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -7,26 +7,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -// A Packet is a packet -type Packet struct { - PacketNumber protocol.PacketNumber - Frames []Frame - LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK - Length protocol.ByteCount - EncryptionLevel protocol.EncryptionLevel - SendTime time.Time - - IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. - - includedInBytesInFlight bool - declaredLost bool - skippedPacket bool -} - -func (p *Packet) outstanding() bool { - return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket -} - // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet diff --git a/internal/ackhandler/packet.go b/internal/ackhandler/packet.go new file mode 100644 index 00000000..b8a47b7a --- /dev/null +++ b/internal/ackhandler/packet.go @@ -0,0 +1,49 @@ +package ackhandler + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A Packet is a packet +type Packet struct { + PacketNumber protocol.PacketNumber + Frames []Frame + LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK + Length protocol.ByteCount + EncryptionLevel protocol.EncryptionLevel + SendTime time.Time + + IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. + + includedInBytesInFlight bool + declaredLost bool + skippedPacket bool +} + +func (p *Packet) outstanding() bool { + return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket +} + +var packetPool = sync.Pool{New: func() any { return &Packet{} }} + +func GetPacket() *Packet { + p := packetPool.Get().(*Packet) + p.PacketNumber = 0 + p.Frames = nil + p.LargestAcked = 0 + p.Length = 0 + p.EncryptionLevel = protocol.EncryptionLevel(0) + p.SendTime = time.Time{} + p.IsPathMTUProbePacket = false + p.includedInBytesInFlight = false + p.declaredLost = false + p.skippedPacket = false + return p +} + +// We currently only return Packets back into the pool when they're acknowledged (not when they're lost). +// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. +func putPacket(p *Packet) { packetPool.Put(p) } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 5a8cd70e..03804f0a 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -334,7 +334,11 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En acked1RTTPacket = true } h.removeFromBytesInFlight(p) + putPacket(p) } + // After this point, we must not use ackedPackets any longer! + // We've already returned the buffers. + ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side. // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 7f8f8d02..d885d7e9 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -38,7 +38,7 @@ var _ = Describe("SentPacketHandler", func() { getPacket := func(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) *Packet { if el, ok := handler.getPacketNumberSpace(encLevel).history.packetMap[pn]; ok { - return &el.Value + return el.Value } return nil } diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index fca26911..2855e1b1 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -11,18 +11,18 @@ import ( type sentPacketHistory struct { rttStats *utils.RTTStats - outstandingPacketList *list.List[Packet] - etcPacketList *list.List[Packet] - packetMap map[protocol.PacketNumber]*list.Element[Packet] + outstandingPacketList *list.List[*Packet] + etcPacketList *list.List[*Packet] + packetMap map[protocol.PacketNumber]*list.Element[*Packet] highestSent protocol.PacketNumber } func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { return &sentPacketHistory{ rttStats: rttStats, - outstandingPacketList: list.New[Packet](), - etcPacketList: list.New[Packet](), - packetMap: make(map[protocol.PacketNumber]*list.Element[Packet]), + outstandingPacketList: list.New[*Packet](), + etcPacketList: list.New[*Packet](), + packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]), highestSent: protocol.InvalidPacketNumber, } } @@ -33,7 +33,7 @@ func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { } // Skipped packet numbers. for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { - el := h.etcPacketList.PushBack(Packet{ + el := h.etcPacketList.PushBack(&Packet{ PacketNumber: pn, EncryptionLevel: p.EncryptionLevel, SendTime: p.SendTime, @@ -44,11 +44,11 @@ func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { h.highestSent = p.PacketNumber if isAckEliciting { - var el *list.Element[Packet] + var el *list.Element[*Packet] if p.outstanding() { - el = h.outstandingPacketList.PushBack(*p) + el = h.outstandingPacketList.PushBack(p) } else { - el = h.etcPacketList.PushBack(*p) + el = h.etcPacketList.PushBack(p) } h.packetMap[p.PacketNumber] = el } @@ -59,7 +59,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err cont := true outstandingEl := h.outstandingPacketList.Front() etcEl := h.etcPacketList.Front() - var el *list.Element[Packet] + var el *list.Element[*Packet] // whichever has the next packet number is returned first for cont { if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) { @@ -76,7 +76,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err etcEl = etcEl.Next() } var err error - cont, err = cb(&el.Value) + cont, err = cb(el.Value) if err != nil { return err } @@ -90,7 +90,7 @@ func (h *sentPacketHistory) FirstOutstanding() *Packet { if el == nil { return nil } - return &el.Value + return el.Value } func (h *sentPacketHistory) Len() int { @@ -114,7 +114,7 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool { func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { maxAge := 3 * h.rttStats.PTO(false) - var nextEl *list.Element[Packet] + var nextEl *list.Element[*Packet] // we don't iterate outstandingPacketList, as we should not delete outstanding packets. // being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes. for el := h.etcPacketList.Front(); el != nil; el = nextEl { @@ -145,10 +145,10 @@ func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet { } } if el == nil { - el = h.etcPacketList.PushFront(*p) + el = h.etcPacketList.PushFront(p) } else { - el = h.etcPacketList.InsertAfter(*p, el) + el = h.etcPacketList.InsertAfter(p, el) } h.packetMap[p.PacketNumber] = el - return &el.Value + return el.Value } diff --git a/packet_packer.go b/packet_packer.go index 92faae7c..6b3840b2 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -100,15 +100,16 @@ func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueu p.frames[i].OnLost = q.AddAppData } } - return &ackhandler.Packet{ - PacketNumber: p.header.PacketNumber, - LargestAcked: largestAcked, - Frames: p.frames, - Length: p.length, - EncryptionLevel: encLevel, - SendTime: now, - IsPathMTUProbePacket: p.isMTUProbePacket, - } + + ap := ackhandler.GetPacket() + ap.PacketNumber = p.header.PacketNumber + ap.LargestAcked = largestAcked + ap.Frames = p.frames + ap.Length = p.length + ap.EncryptionLevel = encLevel + ap.SendTime = now + ap.IsPathMTUProbePacket = p.isMTUProbePacket + return ap } func getMaxPacketSize(addr net.Addr) protocol.ByteCount {