diff --git a/connection.go b/connection.go index 9674737d4..d247da43d 100644 --- a/connection.go +++ b/connection.go @@ -1280,11 +1280,11 @@ func (s *connection) handleFrames( } if log != nil { frames = append(frames, logutils.ConvertFrame(frame)) - // An error occurred handling a previous frame. - // Don't handle the current frame. - if handleErr != nil { - continue - } + } + // An error occurred handling a previous frame. + // Don't handle the current frame. + if handleErr != nil { + continue } if err := s.handleFrame(frame, encLevel, destConnID); err != nil { if log == nil { @@ -1314,7 +1314,6 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel err = s.handleStreamFrame(frame) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel) - wire.PutAckFrame(frame) case *wire.ConnectionCloseFrame: s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 5924f84bd..3b1bf486f 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -11,7 +11,9 @@ import ( type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) - ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + // ReceivedAck processes an ACK frame. + // It does not store a copy of the frame. + ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index 2f816c78f..9b23cc25f 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -22,19 +22,17 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { +func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error { ecn := typ == ackECNFrameType - frame := GetAckFrame() - la, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } largestAcked := protocol.PacketNumber(la) delay, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } delayTime := time.Duration(delay*1< largestAcked { - return nil, errors.New("invalid first ACK range") + return errors.New("invalid first ACK range") } smallest := largestAcked - ackBlock @@ -65,50 +63,50 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc for i := uint64(0); i < numBlocks; i++ { g, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } gap := protocol.PacketNumber(g) if smallest < gap+2 { - return nil, errInvalidAckRanges + return errInvalidAckRanges } largest := smallest - gap - 2 ab, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } ackBlock := protocol.PacketNumber(ab) if ackBlock > largest { - return nil, errInvalidAckRanges + return errInvalidAckRanges } smallest = largest - ackBlock frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } if !frame.validateAckRanges() { - return nil, errInvalidAckRanges + return errInvalidAckRanges } if ecn { ect0, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } frame.ECT0 = ect0 ect1, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } frame.ECT1 = ect1 ecnce, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } frame.ECNCE = ecnce } - return frame, nil + return nil } // Append appends an ACK frame. @@ -251,6 +249,18 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { return p <= f.AckRanges[i].Largest } +func (f *AckFrame) Reset() { + f.DelayTime = 0 + f.ECT0 = 0 + f.ECT1 = 0 + f.ECNCE = 0 + for _, r := range f.AckRanges { + r.Largest = 0 + r.Smallest = 0 + } + f.AckRanges = f.AckRanges[:0] +} + func encodeAckDelay(delay time.Duration) uint64 { return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) } diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index f80335c0c..c94c157d6 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -21,8 +21,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -35,8 +35,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -49,8 +49,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -62,8 +62,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block - _, err := parseAckFrame(bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError("invalid first ACK range")) + var frame AckFrame + Expect(parseAckFrame(&frame, bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError("invalid first ACK range")) }) It("parses an ACK frame that has a single block", func() { @@ -74,7 +74,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + var frame AckFrame + err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) @@ -96,7 +97,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + var frame AckFrame + err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) @@ -121,8 +123,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent+i, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent+i, protocol.Version1)).To(Succeed()) Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) } }) @@ -133,7 +135,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + var frame AckFrame + err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.DelayTime).To(BeNumerically(">", 0)) // The maximum encodable duration is ~292 years. @@ -147,11 +150,11 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(100)...) // first ack block data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block - _, err := parseAckFrame(bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[:i]), ackFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) + var frame AckFrame + Expect(parseAckFrame(&frame, bytes.NewReader(data[:i]), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError(io.EOF)) } }) @@ -165,7 +168,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE b := bytes.NewReader(data) - frame, err := parseAckFrame(b, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) + var frame AckFrame + err := parseAckFrame(&frame, b, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) @@ -186,11 +190,11 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x42)...) // ECT(0) data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - _, err := parseAckFrame(bytes.NewReader(data), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).NotTo(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, bytes.NewReader(data), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) for i := range data { - _, err = parseAckFrame(bytes.NewReader(data[:i]), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) - Expect(err).To(MatchError(io.EOF)) + var frame AckFrame + Expect(parseAckFrame(&frame, bytes.NewReader(data[:i]), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError(io.EOF)) } }) }) @@ -243,9 +247,9 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) Expect(frame.DelayTime).To(Equal(f.DelayTime)) Expect(r.Len()).To(BeZero()) @@ -261,9 +265,9 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) Expect(r.Len()).To(BeZero()) }) @@ -282,9 +286,9 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(r.Len()).To(BeZero()) }) @@ -305,9 +309,9 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(Equal(f)) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(r.Len()).To(BeZero()) }) @@ -329,8 +333,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { r := bytes.NewReader(b) typ, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) - frame, err := parseAckFrame(r, typ, protocol.AckDelayExponent, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) + var frame AckFrame + Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(r.Len()).To(BeZero()) Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges @@ -456,4 +460,21 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(f.AcksPacket(21)).To(BeFalse()) }) }) + + It("resets", func() { + f := &AckFrame{ + DelayTime: time.Second, + AckRanges: []AckRange{{Smallest: 1, Largest: 3}}, + ECT0: 1, + ECT1: 2, + ECNCE: 3, + } + f.Reset() + Expect(f.AckRanges).To(BeEmpty()) + Expect(f.AckRanges).To(HaveCap(1)) + Expect(f.DelayTime).To(BeZero()) + Expect(f.ECT0).To(BeZero()) + Expect(f.ECT1).To(BeZero()) + Expect(f.ECNCE).To(BeZero()) + }) }) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index e624df94e..ff35dd101 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -39,9 +39,12 @@ const ( type frameParser struct { r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them - ackDelayExponent uint8 - + ackDelayExponent uint8 supportsDatagrams bool + + // To avoid allocating when parsing, keep a single ACK frame struct. + // It is used over and over again. + ackFrame *AckFrame } var _ FrameParser = &frameParser{} @@ -51,6 +54,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser { return &frameParser{ r: *bytes.NewReader(nil), supportsDatagrams: supportsDatagrams, + ackFrame: &AckFrame{}, } } @@ -105,7 +109,9 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. if encLevel != protocol.Encryption1RTT { ackDelayExponent = protocol.DefaultAckDelayExponent } - frame, err = parseAckFrame(r, typ, ackDelayExponent, v) + p.ackFrame.Reset() + err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v) + frame = p.ackFrame case resetStreamFrameType: frame, err = parseResetStreamFrame(r, v) case stopSendingFrameType: