From f12ee48617ffca43191740cd8c28c9928c52d683 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 5 May 2024 19:28:28 +0800 Subject: [PATCH] wire: use quicvarint.Parse when parsing frames (#4484) * wire: add benchmarks for the frame parser * wire: use quicvarint.Parse when parsing frames * wire: always use io.EOF for too short frames --- internal/wire/ack_frame.go | 59 ++++---- internal/wire/ack_frame_test.go | 107 +++++++------- internal/wire/connection_close_frame.go | 33 ++--- internal/wire/connection_close_frame_test.go | 25 ++-- internal/wire/crypto_frame.go | 25 ++-- internal/wire/crypto_frame_test.go | 12 +- internal/wire/data_blocked_frame.go | 10 +- internal/wire/data_blocked_frame_test.go | 11 +- internal/wire/datagram_frame.go | 23 ++- internal/wire/datagram_frame_test.go | 18 ++- internal/wire/frame_parser.go | 85 +++++------ internal/wire/frame_parser_test.go | 134 ++++++++++++++++++ internal/wire/max_data_frame.go | 10 +- internal/wire/max_data_frame_test.go | 11 +- internal/wire/max_stream_data_frame.go | 17 +-- internal/wire/max_stream_data_frame_test.go | 12 +- internal/wire/max_streams_frame.go | 11 +- internal/wire/max_streams_frame_test.go | 28 ++-- internal/wire/new_connection_id_frame.go | 47 +++--- internal/wire/new_connection_id_frame_test.go | 17 ++- internal/wire/new_token_frame.go | 22 ++- internal/wire/new_token_frame_test.go | 15 +- internal/wire/path_challenge_frame.go | 15 +- internal/wire/path_challenge_frame_test.go | 13 +- internal/wire/path_response_frame.go | 15 +- internal/wire/path_response_frame_test.go | 12 +- internal/wire/reset_stream_frame.go | 21 +-- internal/wire/reset_stream_frame_test.go | 11 +- internal/wire/retire_connection_id_frame.go | 10 +- .../wire/retire_connection_id_frame_test.go | 10 +- internal/wire/stop_sending_frame.go | 17 +-- internal/wire/stop_sending_frame_test.go | 12 +- internal/wire/stream_data_blocked_frame.go | 16 +-- .../wire/stream_data_blocked_frame_test.go | 11 +- internal/wire/stream_frame.go | 36 ++--- internal/wire/stream_frame_test.go | 49 +++---- internal/wire/streams_blocked_frame.go | 11 +- internal/wire/streams_blocked_frame_test.go | 34 +++-- 38 files changed, 572 insertions(+), 453 deletions(-) diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index 29be077b..e0f2db3c 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "errors" "sort" "time" @@ -22,18 +21,21 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error { +func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) { + startLen := len(b) ecn := typ == ackECNFrameType - la, err := quicvarint.Read(r) + la, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] largestAcked := protocol.PacketNumber(la) - delay, err := quicvarint.Read(r) + delay, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] delayTime := time.Duration(delay*1< largestAcked { - return errors.New("invalid first ACK range") + return 0, errors.New("invalid first ACK range") } smallest := largestAcked - ackBlock frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) // read all the other ACK ranges for i := uint64(0); i < numBlocks; i++ { - g, err := quicvarint.Read(r) + g, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] gap := protocol.PacketNumber(g) if smallest < gap+2 { - return errInvalidAckRanges + return 0, errInvalidAckRanges } largest := smallest - gap - 2 - ab, err := quicvarint.Read(r) + ab, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] ackBlock := protocol.PacketNumber(ab) if ackBlock > largest { - return errInvalidAckRanges + return 0, errInvalidAckRanges } smallest = largest - ackBlock frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } if !frame.validateAckRanges() { - return errInvalidAckRanges + return 0, errInvalidAckRanges } if ecn { - ect0, err := quicvarint.Read(r) + ect0, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] frame.ECT0 = ect0 - ect1, err := quicvarint.Read(r) + ect1, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] frame.ECT1 = ect1 - ecnce, err := quicvarint.Read(r) + ecnce, l, err := quicvarint.Parse(b) if err != nil { - return err + return 0, replaceUnexpectedEOF(err) } + b = b[l:] frame.ECNCE = ecnce } - return nil + return startLen - len(b), nil } // Append appends an ACK frame. diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index c94c157d..d201331a 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "math" "time" @@ -13,20 +12,20 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("ACK Frame (for IETF QUIC)", func() { +var _ = Describe("ACK Frame", func() { Context("parsing", func() { It("parses an ACK frame without any ranges", func() { data := encodeVarInt(100) // largest acked data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block - b := bytes.NewReader(data) var frame AckFrame - Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) }) It("parses an ACK frame that only acks a single packet", func() { @@ -34,13 +33,13 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) var frame AckFrame - Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) }) It("accepts an ACK frame that acks all packets from 0 to largest", func() { @@ -48,13 +47,13 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block - b := bytes.NewReader(data) var frame AckFrame - Expect(parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) }) It("rejects an ACK frame that has a first ACK block which is larger than LargestAcked", func() { @@ -63,7 +62,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block var frame AckFrame - Expect(parseAckFrame(&frame, bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError("invalid first ACK range")) + _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError("invalid first ACK range")) }) It("parses an ACK frame that has a single block", func() { @@ -73,10 +73,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(100)...) // first ack block data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block - b := bytes.NewReader(data) var frame AckFrame - err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -84,7 +84,6 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { {Largest: 1000, Smallest: 900}, {Largest: 800, Smallest: 750}, })) - Expect(b.Len()).To(BeZero()) }) It("parses an ACK frame that has a multiple blocks", func() { @@ -96,10 +95,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // ack block data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block - b := bytes.NewReader(data) var frame AckFrame - err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -108,7 +107,6 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { {Largest: 98, Smallest: 98}, {Largest: 95, Smallest: 94}, })) - Expect(b.Len()).To(BeZero()) }) It("uses the ack delay exponent", func() { @@ -120,11 +118,12 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) for i := uint8(0); i < 8; i++ { - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent+i, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b[l:], typ, protocol.AckDelayExponent+i, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b[l:]))) Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) } }) @@ -134,9 +133,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(math.MaxUint64/5)...) // delay data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block - b := bytes.NewReader(data) var frame AckFrame - err := parseAckFrame(&frame, b, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.DelayTime).To(BeNumerically(">", 0)) // The maximum encodable duration is ~292 years. @@ -151,10 +149,12 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block var frame AckFrame - Expect(parseAckFrame(&frame, bytes.NewReader(data), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) for i := range data { var frame AckFrame - Expect(parseAckFrame(&frame, bytes.NewReader(data[:i]), ackFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError(io.EOF)) + _, err := parseAckFrame(&frame, data[:i], ackFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) } }) @@ -167,17 +167,16 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x42)...) // ECT(0) data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - b := bytes.NewReader(data) var frame AckFrame - err := parseAckFrame(&frame, b, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) Expect(frame.HasMissingRanges()).To(BeFalse()) Expect(frame.ECT0).To(BeEquivalentTo(0x42)) Expect(frame.ECT1).To(BeEquivalentTo(0x12345)) Expect(frame.ECNCE).To(BeEquivalentTo(0x12345678)) - Expect(b.Len()).To(BeZero()) }) It("errors on EOF", func() { @@ -191,10 +190,13 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE var frame AckFrame - Expect(parseAckFrame(&frame, bytes.NewReader(data), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, data, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(data))) for i := range data { var frame AckFrame - Expect(parseAckFrame(&frame, bytes.NewReader(data[:i]), ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)).To(MatchError(io.EOF)) + _, err := parseAckFrame(&frame, data[:i], ackECNFrameType, protocol.AckDelayExponent, protocol.Version1) + Expect(err).To(MatchError(io.EOF)) } }) }) @@ -244,15 +246,16 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) + b = b[l:] var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b))) Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) Expect(frame.DelayTime).To(Equal(f.DelayTime)) - Expect(r.Len()).To(BeZero()) }) It("writes a frame that acks many packets", func() { @@ -262,14 +265,15 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) + b = b[l:] var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b))) Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(r.Len()).To(BeZero()) }) It("writes a frame with a a single gap", func() { @@ -283,14 +287,15 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) + b = b[l:] var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b))) Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(r.Len()).To(BeZero()) }) It("writes a frame with multiple ranges", func() { @@ -306,14 +311,15 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) + b = b[l:] var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b))) Expect(&frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(r.Len()).To(BeZero()) }) It("limits the maximum size of the ACK frame", func() { @@ -330,13 +336,14 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize Expect(len(b)).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) Expect(len(b)).To(BeNumerically("<=", protocol.MaxAckFrameSize)) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) + b = b[l:] var frame AckFrame - Expect(parseAckFrame(&frame, r, typ, protocol.AckDelayExponent, protocol.Version1)).To(Succeed()) + n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(b))) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(r.Len()).To(BeZero()) Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges }) }) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go index fee16478..be11a1b2 100644 --- a/internal/wire/connection_close_frame.go +++ b/internal/wire/connection_close_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -16,40 +15,38 @@ type ConnectionCloseFrame struct { ReasonPhrase string } -func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) { +func parseConnectionCloseFrame(b []byte, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, int, error) { + startLen := len(b) f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType} - ec, err := quicvarint.Read(r) + ec, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] f.ErrorCode = ec // read the Frame Type, if this is not an application error if !f.IsApplicationError { - ft, err := quicvarint.Read(r) + ft, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] f.FrameType = ft } var reasonPhraseLen uint64 - reasonPhraseLen, err = quicvarint.Read(r) + reasonPhraseLen, l, err = quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - // shortcut to prevent the unnecessary allocation of dataLen bytes - // if the dataLen is larger than the remaining length of the packet - // reading the whole reason phrase would result in EOF when attempting to READ - if int(reasonPhraseLen) > r.Len() { - return nil, io.EOF + b = b[l:] + if int(reasonPhraseLen) > len(b) { + return nil, 0, io.EOF } reasonPhrase := make([]byte, reasonPhraseLen) - if _, err := io.ReadFull(r, reasonPhrase); err != nil { - // this should never happen, since we already checked the reasonPhraseLen earlier - return nil, err - } + copy(reasonPhrase, b) f.ReasonPhrase = string(reasonPhrase) - return f, nil + return f, startLen - len(b) + int(reasonPhraseLen), nil } // Length of a written frame diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go index 0b36c1f7..a7db15fb 100644 --- a/internal/wire/connection_close_frame_test.go +++ b/internal/wire/connection_close_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -18,14 +17,13 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(0x1337)...) // frame type data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, connectionCloseFrameType, protocol.Version1) + frame, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeFalse()) Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("accepts sample frame containing an application error code", func() { @@ -33,20 +31,19 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, reason...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, applicationCloseFrameType, protocol.Version1) + frame, l, err := parseConnectionCloseFrame(data, applicationCloseFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeTrue()) Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) Expect(frame.ReasonPhrase).To(Equal(reason)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("rejects long reason phrases", func() { data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0xffff)...) // reason phrase length - _, err := parseConnectionCloseFrame(bytes.NewReader(data), connectionCloseFrameType, protocol.Version1) + _, _, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1) Expect(err).To(MatchError(io.EOF)) }) @@ -56,12 +53,11 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(0x1337)...) // frame type data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) - b := bytes.NewReader(data) - _, err := parseConnectionCloseFrame(b, connectionCloseFrameType, protocol.Version1) + _, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1) + Expect(l).To(Equal(len(data))) Expect(err).NotTo(HaveOccurred()) for i := range data { - b := bytes.NewReader(data[:i]) - _, err = parseConnectionCloseFrame(b, connectionCloseFrameType, protocol.Version1) + _, _, err = parseConnectionCloseFrame(data[:i], connectionCloseFrameType, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -70,11 +66,10 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data := encodeVarInt(0xcafe) data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0)...) - b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, connectionCloseFrameType, protocol.Version1) + frame, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.ReasonPhrase).To(BeEmpty()) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) }) diff --git a/internal/wire/crypto_frame.go b/internal/wire/crypto_frame.go index c219064f..0aa7fe7b 100644 --- a/internal/wire/crypto_frame.go +++ b/internal/wire/crypto_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -14,28 +13,28 @@ type CryptoFrame struct { Data []byte } -func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) { +func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) { + startLen := len(b) frame := &CryptoFrame{} - offset, err := quicvarint.Read(r) + offset, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] frame.Offset = protocol.ByteCount(offset) - dataLen, err := quicvarint.Read(r) + dataLen, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - if dataLen > uint64(r.Len()) { - return nil, io.EOF + b = b[l:] + if dataLen > uint64(len(b)) { + return nil, 0, io.EOF } if dataLen != 0 { frame.Data = make([]byte, dataLen) - if _, err := io.ReadFull(r, frame.Data); err != nil { - // this should never happen, since we already checked the dataLen earlier - return nil, err - } + copy(frame.Data, b) } - return frame, nil + return frame, startLen - len(b) + int(dataLen), nil } func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/crypto_frame_test.go b/internal/wire/crypto_frame_test.go index 9cdc8682..7eecab89 100644 --- a/internal/wire/crypto_frame_test.go +++ b/internal/wire/crypto_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -17,23 +16,22 @@ var _ = Describe("CRYPTO frame", func() { data := encodeVarInt(0xdecafbad) // offset data = append(data, encodeVarInt(6)...) // length data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseCryptoFrame(r, protocol.Version1) + frame, l, err := parseCryptoFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdecafbad) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - _, err := parseCryptoFrame(r, protocol.Version1) + _, l, err := parseCryptoFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseCryptoFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseCryptoFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/data_blocked_frame.go b/internal/wire/data_blocked_frame.go index 1593e520..c97d4c62 100644 --- a/internal/wire/data_blocked_frame.go +++ b/internal/wire/data_blocked_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) @@ -12,12 +10,12 @@ type DataBlockedFrame struct { MaximumData protocol.ByteCount } -func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) { - offset, err := quicvarint.Read(r) +func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) { + offset, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil + return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil } func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) { diff --git a/internal/wire/data_blocked_frame_test.go b/internal/wire/data_blocked_frame_test.go index 60c5e7db..38e781ee 100644 --- a/internal/wire/data_blocked_frame_test.go +++ b/internal/wire/data_blocked_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -15,19 +14,19 @@ var _ = Describe("DATA_BLOCKED frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { data := encodeVarInt(0x12345678) - b := bytes.NewReader(data) - frame, err := parseDataBlockedFrame(b, protocol.Version1) + frame, l, err := parseDataBlockedFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0x12345678) - _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseDataBlockedFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseDataBlockedFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go index 6562fa5b..071fda9a 100644 --- a/internal/wire/datagram_frame.go +++ b/internal/wire/datagram_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -20,29 +19,29 @@ type DatagramFrame struct { Data []byte } -func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) { +func parseDatagramFrame(b []byte, typ uint64, _ protocol.Version) (*DatagramFrame, int, error) { + startLen := len(b) f := &DatagramFrame{} f.DataLenPresent = typ&0x1 > 0 var length uint64 if f.DataLenPresent { var err error - len, err := quicvarint.Read(r) + var l int + length, l, err = quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - if len > uint64(r.Len()) { - return nil, io.EOF + b = b[l:] + if length > uint64(len(b)) { + return nil, 0, io.EOF } - length = len } else { - length = uint64(r.Len()) + length = uint64(len(b)) } f.Data = make([]byte, length) - if _, err := io.ReadFull(r, f.Data); err != nil { - return nil, err - } - return f, nil + copy(f.Data, b) + return f, startLen - len(b) + int(length), nil } func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go index 67af7aec..214db3e3 100644 --- a/internal/wire/datagram_frame_test.go +++ b/internal/wire/datagram_frame_test.go @@ -16,29 +16,26 @@ var _ = Describe("STREAM frame", func() { It("parses a frame containing a length", func() { data := encodeVarInt(0x6) // length data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, 0x30^0x1, protocol.Version1) + frame, l, err := parseDatagramFrame(data, 0x30^0x1, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Data).To(Equal([]byte("foobar"))) Expect(frame.DataLenPresent).To(BeTrue()) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("parses a frame without length", func() { data := []byte("Lorem ipsum dolor sit amet") - r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, 0x30, protocol.Version1) + frame, l, err := parseDatagramFrame(data, 0x30, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) Expect(frame.DataLenPresent).To(BeFalse()) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors when the length is longer than the rest of the frame", func() { data := encodeVarInt(0x6) // length data = append(data, []byte("fooba")...) - r := bytes.NewReader(data) - _, err := parseDatagramFrame(r, 0x30^0x1, protocol.Version1) + _, _, err := parseDatagramFrame(data, 0x30^0x1, protocol.Version1) Expect(err).To(MatchError(io.EOF)) }) @@ -46,10 +43,11 @@ var _ = Describe("STREAM frame", func() { const typ = 0x30 ^ 0x1 data := encodeVarInt(6) // length data = append(data, []byte("foobar")...) - _, err := parseDatagramFrame(bytes.NewReader(data), typ, protocol.Version1) + _, l, err := parseDatagramFrame(data, typ, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err = parseDatagramFrame(bytes.NewReader(data[0:i]), typ, protocol.Version1) + _, _, err = parseDatagramFrame(data[0:i], typ, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index cf7d4cec..59d41444 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -1,9 +1,9 @@ package wire import ( - "bytes" "errors" "fmt" + "io" "reflect" "github.com/quic-go/quic-go/internal/protocol" @@ -38,8 +38,6 @@ const ( // The FrameParser parses QUIC frames, one by one. type FrameParser struct { - r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them - ackDelayExponent uint8 supportsDatagrams bool @@ -51,7 +49,6 @@ type FrameParser struct { // NewFrameParser creates a new frame parser. func NewFrameParser(supportsDatagrams bool) *FrameParser { return &FrameParser{ - r: *bytes.NewReader(nil), supportsDatagrams: supportsDatagrams, ackFrame: &AckFrame{}, } @@ -60,45 +57,46 @@ func NewFrameParser(supportsDatagrams bool) *FrameParser { // ParseNext parses the next frame. // It skips PADDING frames. func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) { - startLen := len(data) - p.r.Reset(data) - frame, err := p.parseNext(&p.r, encLevel, v) - n := startLen - p.r.Len() - p.r.Reset(nil) - return n, frame, err + frame, l, err := p.parseNext(data, encLevel, v) + return l, frame, err } -func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { - for r.Len() != 0 { - typ, err := quicvarint.Read(r) +func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) { + var parsed int + for len(b) != 0 { + typ, l, err := quicvarint.Parse(b) + parsed += l if err != nil { - return nil, &qerr.TransportError{ + return nil, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, ErrorMessage: err.Error(), } } + b = b[l:] if typ == 0x0 { // skip PADDING frames continue } - f, err := p.parseFrame(r, typ, encLevel, v) + f, l, err := p.parseFrame(b, typ, encLevel, v) + parsed += l if err != nil { - return nil, &qerr.TransportError{ + return nil, parsed, &qerr.TransportError{ FrameType: typ, ErrorCode: qerr.FrameEncodingError, ErrorMessage: err.Error(), } } - return f, nil + return f, parsed, nil } - return nil, nil + return nil, parsed, nil } -func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { +func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) { var frame Frame var err error + var l int if typ&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, typ, v) + frame, l, err = parseStreamFrame(b, typ, v) } else { switch typ { case pingFrameType: @@ -109,43 +107,43 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. ackDelayExponent = protocol.DefaultAckDelayExponent } p.ackFrame.Reset() - err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v) + l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v) frame = p.ackFrame case resetStreamFrameType: - frame, err = parseResetStreamFrame(r, v) + frame, l, err = parseResetStreamFrame(b, v) case stopSendingFrameType: - frame, err = parseStopSendingFrame(r, v) + frame, l, err = parseStopSendingFrame(b, v) case cryptoFrameType: - frame, err = parseCryptoFrame(r, v) + frame, l, err = parseCryptoFrame(b, v) case newTokenFrameType: - frame, err = parseNewTokenFrame(r, v) + frame, l, err = parseNewTokenFrame(b, v) case maxDataFrameType: - frame, err = parseMaxDataFrame(r, v) + frame, l, err = parseMaxDataFrame(b, v) case maxStreamDataFrameType: - frame, err = parseMaxStreamDataFrame(r, v) + frame, l, err = parseMaxStreamDataFrame(b, v) case bidiMaxStreamsFrameType, uniMaxStreamsFrameType: - frame, err = parseMaxStreamsFrame(r, typ, v) + frame, l, err = parseMaxStreamsFrame(b, typ, v) case dataBlockedFrameType: - frame, err = parseDataBlockedFrame(r, v) + frame, l, err = parseDataBlockedFrame(b, v) case streamDataBlockedFrameType: - frame, err = parseStreamDataBlockedFrame(r, v) + frame, l, err = parseStreamDataBlockedFrame(b, v) case bidiStreamBlockedFrameType, uniStreamBlockedFrameType: - frame, err = parseStreamsBlockedFrame(r, typ, v) + frame, l, err = parseStreamsBlockedFrame(b, typ, v) case newConnectionIDFrameType: - frame, err = parseNewConnectionIDFrame(r, v) + frame, l, err = parseNewConnectionIDFrame(b, v) case retireConnectionIDFrameType: - frame, err = parseRetireConnectionIDFrame(r, v) + frame, l, err = parseRetireConnectionIDFrame(b, v) case pathChallengeFrameType: - frame, err = parsePathChallengeFrame(r, v) + frame, l, err = parsePathChallengeFrame(b, v) case pathResponseFrameType: - frame, err = parsePathResponseFrame(r, v) + frame, l, err = parsePathResponseFrame(b, v) case connectionCloseFrameType, applicationCloseFrameType: - frame, err = parseConnectionCloseFrame(r, typ, v) + frame, l, err = parseConnectionCloseFrame(b, typ, v) case handshakeDoneFrameType: frame = &HandshakeDoneFrame{} case 0x30, 0x31: if p.supportsDatagrams { - frame, err = parseDatagramFrame(r, typ, v) + frame, l, err = parseDatagramFrame(b, typ, v) break } fallthrough @@ -154,12 +152,12 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. } } if err != nil { - return nil, err + return nil, 0, err } if !p.isAllowedAtEncLevel(frame, encLevel) { - return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) } - return frame, nil + return frame, l, nil } func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { @@ -190,3 +188,10 @@ func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL func (p *FrameParser) SetAckDelayExponent(exp uint8) { p.ackDelayExponent = exp } + +func replaceUnexpectedEOF(e error) error { + if e == io.ErrUnexpectedEOF { + return io.EOF + } + return e +} diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 934d18d8..4f8ccefe 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -1,8 +1,12 @@ package wire import ( + "bytes" + "testing" "time" + "golang.org/x/exp/rand" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -433,3 +437,133 @@ var _ = Describe("Frame parsing", func() { }) }) }) + +// STREAM and ACK are the most relevant frames for high-throughput transfers. +func BenchmarkParseStreamAndACK(b *testing.B) { + ack := &AckFrame{ + AckRanges: []AckRange{ + {Smallest: 5000, Largest: 5200}, + {Smallest: 1, Largest: 4200}, + }, + DelayTime: 42 * time.Millisecond, + ECT0: 5000, + ECT1: 0, + ECNCE: 10, + } + sf := &StreamFrame{ + StreamID: 1337, + Offset: 1e7, + Data: make([]byte, 200), + DataLenPresent: true, + } + rand.Read(sf.Data) + + data, err := ack.Append([]byte{}, protocol.Version1) + if err != nil { + b.Fatal(err) + } + data, err = sf.Append(data, protocol.Version1) + if err != nil { + b.Fatal(err) + } + + parser := NewFrameParser(false) + parser.SetAckDelayExponent(3) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + l, f, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) + if err != nil { + b.Fatal(err) + } + ackParsed := f.(*AckFrame) + if ackParsed.DelayTime != ack.DelayTime || ackParsed.ECNCE != ack.ECNCE { + b.Fatalf("incorrect ACK frame: %v vs %v", ack, ackParsed) + } + l2, f, err := parser.ParseNext(data[l:], protocol.Encryption1RTT, protocol.Version1) + if err != nil { + b.Fatal(err) + } + if len(data[l:]) != l2 { + b.Fatal("didn't parse the entire packet") + } + sfParsed := f.(*StreamFrame) + if sfParsed.StreamID != sf.StreamID || !bytes.Equal(sfParsed.Data, sf.Data) { + b.Fatalf("incorrect STREAM frame: %v vs %v", sf, sfParsed) + } + } +} + +func BenchmarkParseOtherFrames(b *testing.B) { + maxDataFrame := &MaxDataFrame{MaximumData: 123456} + maxStreamsFrame := &MaxStreamsFrame{MaxStreamNum: 10} + maxStreamDataFrame := &MaxStreamDataFrame{StreamID: 1337, MaximumStreamData: 1e6} + cryptoFrame := &CryptoFrame{Offset: 1000, Data: make([]byte, 128)} + resetStreamFrame := &ResetStreamFrame{StreamID: 87654, ErrorCode: 1234, FinalSize: 1e8} + rand.Read(cryptoFrame.Data) + frames := []Frame{ + maxDataFrame, + maxStreamsFrame, + maxStreamDataFrame, + cryptoFrame, + &PingFrame{}, + resetStreamFrame, + } + var buf []byte + for i, frame := range frames { + var err error + buf, err = frame.Append(buf, protocol.Version1) + if err != nil { + b.Fatal(err) + } + if i == len(frames)/2 { + // add 3 PADDING frames + buf = append(buf, 0) + buf = append(buf, 0) + buf = append(buf, 0) + } + } + + parser := NewFrameParser(false) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + data := buf + for j := 0; j < len(frames); j++ { + l, f, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) + if err != nil { + b.Fatal(err) + } + data = data[l:] + switch j { + case 0: + if f.(*MaxDataFrame).MaximumData != maxDataFrame.MaximumData { + b.Fatalf("MAX_DATA frame does not match: %v vs %v", f, maxDataFrame) + } + case 1: + if f.(*MaxStreamsFrame).MaxStreamNum != maxStreamsFrame.MaxStreamNum { + b.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, maxStreamsFrame) + } + case 2: + if f.(*MaxStreamDataFrame).StreamID != maxStreamDataFrame.StreamID || + f.(*MaxStreamDataFrame).MaximumStreamData != maxStreamDataFrame.MaximumStreamData { + b.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, maxStreamDataFrame) + } + case 3: + if f.(*CryptoFrame).Offset != cryptoFrame.Offset || !bytes.Equal(f.(*CryptoFrame).Data, cryptoFrame.Data) { + b.Fatalf("CRYPTO frame does not match: %v vs %v", f, cryptoFrame) + } + case 4: + _ = f.(*PingFrame) + case 5: + rst := f.(*ResetStreamFrame) + if rst.StreamID != resetStreamFrame.StreamID || rst.ErrorCode != resetStreamFrame.ErrorCode || + rst.FinalSize != resetStreamFrame.FinalSize { + b.Fatalf("RESET_STREAM frame does not match: %v vs %v", rst, resetStreamFrame) + } + } + } + } +} diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go index a26e400e..5819c027 100644 --- a/internal/wire/max_data_frame.go +++ b/internal/wire/max_data_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) @@ -13,14 +11,14 @@ type MaxDataFrame struct { } // parseMaxDataFrame parses a MAX_DATA frame -func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) { +func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) { frame := &MaxDataFrame{} - byteOffset, err := quicvarint.Read(r) + byteOffset, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } frame.MaximumData = protocol.ByteCount(byteOffset) - return frame, nil + return frame, l, nil } func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go index 93980013..ed8b6d87 100644 --- a/internal/wire/max_data_frame_test.go +++ b/internal/wire/max_data_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -15,19 +14,19 @@ var _ = Describe("MAX_DATA frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { data := encodeVarInt(0xdecafbad123456) // byte offset - b := bytes.NewReader(data) - frame, err := parseMaxDataFrame(b, protocol.Version1) + frame, l, err := parseMaxDataFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdecafbad1234567) // byte offset - _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseMaxDataFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseMaxDataFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseMaxDataFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go index 6a0f5a9c..db9091af 100644 --- a/internal/wire/max_stream_data_frame.go +++ b/internal/wire/max_stream_data_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) @@ -13,20 +11,23 @@ type MaxStreamDataFrame struct { MaximumStreamData protocol.ByteCount } -func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) { - sid, err := quicvarint.Read(r) +func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) { + startLen := len(b) + sid, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - offset, err := quicvarint.Read(r) + b = b[l:] + offset, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] return &MaxStreamDataFrame{ StreamID: protocol.StreamID(sid), MaximumStreamData: protocol.ByteCount(offset), - }, nil + }, startLen - len(b), nil } func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go index 1ca7ff4f..f3a6e92a 100644 --- a/internal/wire/max_stream_data_frame_test.go +++ b/internal/wire/max_stream_data_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -16,22 +15,21 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { It("accepts sample frame", func() { data := encodeVarInt(0xdeadbeef) // Stream ID data = append(data, encodeVarInt(0x12345678)...) // Offset - b := bytes.NewReader(data) - frame, err := parseMaxStreamDataFrame(b, protocol.Version1) + frame, l, err := parseMaxStreamDataFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdeadbeef) // Stream ID data = append(data, encodeVarInt(0x12345678)...) // Offset - b := bytes.NewReader(data) - _, err := parseMaxStreamDataFrame(b, protocol.Version1) + _, l, err := parseMaxStreamDataFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseMaxStreamDataFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseMaxStreamDataFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index ce6816ff..a8745bd1 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "fmt" "github.com/quic-go/quic-go/internal/protocol" @@ -14,7 +13,7 @@ type MaxStreamsFrame struct { MaxStreamNum protocol.StreamNum } -func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) { +func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) { f := &MaxStreamsFrame{} switch typ { case bidiMaxStreamsFrameType: @@ -22,15 +21,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*Max case uniMaxStreamsFrameType: f.Type = protocol.StreamTypeUni } - streamID, err := quicvarint.Read(r) + streamID, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } f.MaxStreamNum = protocol.StreamNum(streamID) if f.MaxStreamNum > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) } - return f, nil + return f, l, nil } func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go index 7aee000c..ae41e849 100644 --- a/internal/wire/max_streams_frame_test.go +++ b/internal/wire/max_streams_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "fmt" "io" @@ -16,31 +15,30 @@ var _ = Describe("MAX_STREAMS frame", func() { Context("parsing", func() { It("accepts a frame for a bidirectional stream", func() { data := encodeVarInt(0xdecaf) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, bidiMaxStreamsFrameType, protocol.Version1) + f, l, err := parseMaxStreamsFrame(data, bidiMaxStreamsFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("accepts a frame for a bidirectional stream", func() { data := encodeVarInt(0xdecaf) - b := bytes.NewReader(data) - f, err := parseMaxStreamsFrame(b, uniMaxStreamsFrameType, protocol.Version1) + f, l, err := parseMaxStreamsFrame(data, uniMaxStreamsFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.Type).To(Equal(protocol.StreamTypeUni)) Expect(f.MaxStreamNum).To(BeEquivalentTo(0xdecaf)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { const typ = 0x1d data := encodeVarInt(0xdeadbeefcafe13) - _, err := parseMaxStreamsFrame(bytes.NewReader(data), typ, protocol.Version1) + _, l, err := parseMaxStreamsFrame(data, typ, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err = parseMaxStreamsFrame(bytes.NewReader(data[:i]), typ, protocol.Version1) + _, _, err := parseMaxStreamsFrame(data[:i], typ, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -55,10 +53,10 @@ var _ = Describe("MAX_STREAMS frame", func() { } b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) - frame, err := parseMaxStreamsFrame(r, typ, protocol.Version1) + b = b[l:] + frame, _, err := parseMaxStreamsFrame(b, typ, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -70,10 +68,10 @@ var _ = Describe("MAX_STREAMS frame", func() { } b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) - _, err = parseMaxStreamsFrame(r, typ, protocol.Version1) + b = b[l:] + _, _, err = parseMaxStreamsFrame(b, typ, protocol.Version1) Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) }) } diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index f29773dd..852d46ef 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "errors" "fmt" "io" @@ -18,43 +17,47 @@ type NewConnectionIDFrame struct { StatelessResetToken protocol.StatelessResetToken } -func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) { - seq, err := quicvarint.Read(r) +func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) { + startLen := len(b) + seq, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - ret, err := quicvarint.Read(r) + b = b[l:] + ret, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] if ret > seq { //nolint:stylecheck - return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) + return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) } - connIDLen, err := r.ReadByte() - if err != nil { - return nil, err + if len(b) == 0 { + return nil, 0, io.EOF } + connIDLen := int(b[0]) + b = b[1:] if connIDLen == 0 { - return nil, errors.New("invalid zero-length connection ID") + return nil, 0, errors.New("invalid zero-length connection ID") } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return nil, err + if connIDLen > protocol.MaxConnIDLen { + return nil, 0, protocol.ErrInvalidConnectionIDLen + } + if len(b) < connIDLen { + return nil, 0, io.EOF } frame := &NewConnectionIDFrame{ SequenceNumber: seq, RetirePriorTo: ret, - ConnectionID: connID, + ConnectionID: protocol.ParseConnectionID(b[:connIDLen]), } - if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err + b = b[connIDLen:] + if len(b) < len(frame.StatelessResetToken) { + return nil, 0, io.EOF } - - return frame, nil + copy(frame.StatelessResetToken[:], b) + return frame, startLen - len(b) + len(frame.StatelessResetToken), nil } func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index 02ef3000..e29aa98b 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -18,14 +17,13 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, 10) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - b := bytes.NewReader(data) - frame, err := parseNewConnectionIDFrame(b, protocol.Version1) + frame, l, err := parseNewConnectionIDFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) Expect(frame.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors when the Retire Prior To value is larger than the Sequence Number", func() { @@ -34,7 +32,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, 3) data = append(data, []byte{1, 2, 3}...) data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) }) @@ -42,7 +40,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data := encodeVarInt(42) // sequence number data = append(data, encodeVarInt(12)...) // retire prior to data = append(data, 0) // connection ID length - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) Expect(err).To(MatchError("invalid zero-length connection ID")) }) @@ -52,7 +50,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, 21) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + _, _, err := parseNewConnectionIDFrame(data, protocol.Version1) Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen)) }) @@ -62,10 +60,11 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, 10) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseNewConnectionIDFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseNewConnectionIDFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseNewConnectionIDFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go index b2c86321..f1d4d00f 100644 --- a/internal/wire/new_token_frame.go +++ b/internal/wire/new_token_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "errors" "io" @@ -14,22 +13,21 @@ type NewTokenFrame struct { Token []byte } -func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) { - tokenLen, err := quicvarint.Read(r) +func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) { + tokenLen, l, err := quicvarint.Parse(b) if err != nil { - return nil, err - } - if uint64(r.Len()) < tokenLen { - return nil, io.EOF + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] if tokenLen == 0 { - return nil, errors.New("token must not be empty") + return nil, 0, errors.New("token must not be empty") + } + if uint64(len(b)) < tokenLen { + return nil, 0, io.EOF } token := make([]byte, int(tokenLen)) - if _, err := io.ReadFull(r, token); err != nil { - return nil, err - } - return &NewTokenFrame{Token: token}, nil + copy(token, b) + return &NewTokenFrame{Token: token}, l + int(tokenLen), nil } func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/new_token_frame_test.go b/internal/wire/new_token_frame_test.go index 92f1721f..d835895e 100644 --- a/internal/wire/new_token_frame_test.go +++ b/internal/wire/new_token_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -17,17 +16,15 @@ var _ = Describe("NEW_TOKEN frame", func() { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." data := encodeVarInt(uint64(len(token))) data = append(data, token...) - b := bytes.NewReader(data) - f, err := parseNewTokenFrame(b, protocol.Version1) + f, l, err := parseNewTokenFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(string(f.Token)).To(Equal(token)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("rejects empty tokens", func() { data := encodeVarInt(0) - b := bytes.NewReader(data) - _, err := parseNewTokenFrame(b, protocol.Version1) + _, _, err := parseNewTokenFrame(data, protocol.Version1) Expect(err).To(MatchError("token must not be empty")) }) @@ -35,11 +32,11 @@ var _ = Describe("NEW_TOKEN frame", func() { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit" data := encodeVarInt(uint64(len(token))) data = append(data, token...) - r := bytes.NewReader(data) - _, err := parseNewTokenFrame(r, protocol.Version1) + _, l, err := parseNewTokenFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseNewTokenFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseNewTokenFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go index 772041ac..2aca989f 100644 --- a/internal/wire/path_challenge_frame.go +++ b/internal/wire/path_challenge_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -12,15 +11,13 @@ type PathChallengeFrame struct { Data [8]byte } -func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) { - frame := &PathChallengeFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err +func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) { + f := &PathChallengeFrame{} + if len(b) < 8 { + return nil, 0, io.EOF } - return frame, nil + copy(f.Data[:], b) + return f, 8, nil } func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go index 0f15a417..b6d9b8e0 100644 --- a/internal/wire/path_challenge_frame_test.go +++ b/internal/wire/path_challenge_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -13,20 +12,20 @@ import ( var _ = Describe("PATH_CHALLENGE frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathChallengeFrame(b, protocol.Version1) + b := []byte{1, 2, 3, 4, 5, 6, 7, 8} + f, l, err := parsePathChallengeFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(l).To(Equal(len(b))) }) It("errors on EOFs", func() { data := []byte{1, 2, 3, 4, 5, 6, 7, 8} - b := bytes.NewReader(data) - _, err := parsePathChallengeFrame(b, protocol.Version1) + _, l, err := parsePathChallengeFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parsePathChallengeFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parsePathChallengeFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go index 86bbe619..76532c85 100644 --- a/internal/wire/path_response_frame.go +++ b/internal/wire/path_response_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -12,15 +11,13 @@ type PathResponseFrame struct { Data [8]byte } -func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) { - frame := &PathResponseFrame{} - if _, err := io.ReadFull(r, frame.Data[:]); err != nil { - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return nil, err +func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) { + f := &PathResponseFrame{} + if len(b) < 8 { + return nil, 0, io.EOF } - return frame, nil + copy(f.Data[:], b) + return f, 8, nil } func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go index 48e17742..f66c1de2 100644 --- a/internal/wire/path_response_frame_test.go +++ b/internal/wire/path_response_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -13,19 +12,20 @@ import ( var _ = Describe("PATH_RESPONSE frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - f, err := parsePathResponseFrame(b, protocol.Version1) + b := []byte{1, 2, 3, 4, 5, 6, 7, 8} + f, l, err := parsePathResponseFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeZero()) Expect(f.Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(l).To(Equal(len(b))) }) It("errors on EOFs", func() { data := []byte{1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parsePathResponseFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parsePathResponseFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parsePathResponseFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go index 4a3d0175..a20029af 100644 --- a/internal/wire/reset_stream_frame.go +++ b/internal/wire/reset_stream_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" @@ -15,21 +13,24 @@ type ResetStreamFrame struct { FinalSize protocol.ByteCount } -func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) { +func parseResetStreamFrame(b []byte, _ protocol.Version) (*ResetStreamFrame, int, error) { + startLen := len(b) var streamID protocol.StreamID var byteOffset protocol.ByteCount - sid, err := quicvarint.Read(r) + sid, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] streamID = protocol.StreamID(sid) - errorCode, err := quicvarint.Read(r) + errorCode, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - bo, err := quicvarint.Read(r) + b = b[l:] + bo, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } byteOffset = protocol.ByteCount(bo) @@ -37,7 +38,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFra StreamID: streamID, ErrorCode: qerr.StreamErrorCode(errorCode), FinalSize: byteOffset, - }, nil + }, startLen - len(b) + l, nil } func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go index 26512c54..61e33b0d 100644 --- a/internal/wire/reset_stream_frame_test.go +++ b/internal/wire/reset_stream_frame_test.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" @@ -17,22 +15,23 @@ var _ = Describe("RESET_STREAM frame", func() { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset - b := bytes.NewReader(data) - frame, err := parseResetStreamFrame(b, protocol.Version1) + frame, l, err := parseResetStreamFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset - _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseResetStreamFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseResetStreamFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseResetStreamFrame(data[:i], protocol.Version1) Expect(err).To(HaveOccurred()) } }) diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go index 8a264706..27aeff84 100644 --- a/internal/wire/retire_connection_id_frame.go +++ b/internal/wire/retire_connection_id_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) @@ -12,12 +10,12 @@ type RetireConnectionIDFrame struct { SequenceNumber uint64 } -func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) { - seq, err := quicvarint.Read(r) +func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) { + seq, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - return &RetireConnectionIDFrame{SequenceNumber: seq}, nil + return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil } func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go index b679e41e..a6eae1f9 100644 --- a/internal/wire/retire_connection_id_frame_test.go +++ b/internal/wire/retire_connection_id_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -14,18 +13,19 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { Context("when parsing", func() { It("accepts a sample frame", func() { data := encodeVarInt(0xdeadbeef) // sequence number - b := bytes.NewReader(data) - frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) + frame, l, err := parseRetireConnectionIDFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdeadbeef) // sequence number - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseRetireConnectionIDFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseRetireConnectionIDFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go index 5f00a817..a2326f8e 100644 --- a/internal/wire/stop_sending_frame.go +++ b/internal/wire/stop_sending_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" @@ -15,20 +13,23 @@ type StopSendingFrame struct { } // parseStopSendingFrame parses a STOP_SENDING frame -func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) { - streamID, err := quicvarint.Read(r) +func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) { + startLen := len(b) + streamID, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - errorCode, err := quicvarint.Read(r) + b = b[l:] + errorCode, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] return &StopSendingFrame{ StreamID: protocol.StreamID(streamID), ErrorCode: qerr.StreamErrorCode(errorCode), - }, nil + }, startLen - len(b), nil } // Length of a written frame diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go index bb67531d..abf350ec 100644 --- a/internal/wire/stop_sending_frame_test.go +++ b/internal/wire/stop_sending_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -17,22 +16,21 @@ var _ = Describe("STOP_SENDING frame", func() { It("parses a sample frame", func() { data := encodeVarInt(0xdecafbad) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code - b := bytes.NewReader(data) - frame, err := parseStopSendingFrame(b, protocol.Version1) + frame, l, err := parseStopSendingFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdecafbad) // stream ID data = append(data, encodeVarInt(0x123456)...) // error code - b := bytes.NewReader(data) - _, err := parseStopSendingFrame(b, protocol.Version1) + _, l, err := parseStopSendingFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseStopSendingFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/stream_data_blocked_frame.go b/internal/wire/stream_data_blocked_frame.go index 2dca5e51..3762ec76 100644 --- a/internal/wire/stream_data_blocked_frame.go +++ b/internal/wire/stream_data_blocked_frame.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" ) @@ -13,20 +11,22 @@ type StreamDataBlockedFrame struct { MaximumStreamData protocol.ByteCount } -func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) { - sid, err := quicvarint.Read(r) +func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) { + startLen := len(b) + sid, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } - offset, err := quicvarint.Read(r) + b = b[l:] + offset, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } return &StreamDataBlockedFrame{ StreamID: protocol.StreamID(sid), MaximumStreamData: protocol.ByteCount(offset), - }, nil + }, startLen - len(b) + l, nil } func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go index 677d211e..4d833a35 100644 --- a/internal/wire/stream_data_blocked_frame_test.go +++ b/internal/wire/stream_data_blocked_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "io" "github.com/quic-go/quic-go/internal/protocol" @@ -16,21 +15,21 @@ var _ = Describe("STREAM_DATA_BLOCKED frame", func() { It("accepts sample frame", func() { data := encodeVarInt(0xdeadbeef) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset - b := bytes.NewReader(data) - frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) + frame, l, err := parseStreamDataBlockedFrame(data, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0xdeadbeef) data = append(data, encodeVarInt(0xc0010ff)...) - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) + _, l, err := parseStreamDataBlockedFrame(data, protocol.Version1) Expect(err).NotTo(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) + _, _, err := parseStreamDataBlockedFrame(data[:i], protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index e7df3cea..f9470ecd 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "errors" "io" @@ -20,33 +19,41 @@ type StreamFrame struct { fromPool bool } -func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) { +func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) { + startLen := len(b) hasOffset := typ&0b100 > 0 fin := typ&0b1 > 0 hasDataLen := typ&0b10 > 0 - streamID, err := quicvarint.Read(r) + streamID, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] var offset uint64 if hasOffset { - offset, err = quicvarint.Read(r) + offset, l, err = quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } + b = b[l:] } var dataLen uint64 if hasDataLen { var err error - dataLen, err = quicvarint.Read(r) + var l int + dataLen, l, err = quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) + } + b = b[l:] + if dataLen > uint64(len(b)) { + return nil, 0, io.EOF } } else { // The rest of the packet is data - dataLen = uint64(r.Len()) + dataLen = uint64(len(b)) } var frame *StreamFrame @@ -57,7 +64,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, // since those StreamFrames have a buffer length of the maximum packet size. if dataLen > uint64(cap(frame.Data)) { - return nil, io.EOF + return nil, 0, io.EOF } frame.Data = frame.Data[:dataLen] } @@ -68,17 +75,14 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF frame.DataLenPresent = hasDataLen if dataLen != 0 { - if _, err := io.ReadFull(r, frame.Data); err != nil { - return nil, err - } + copy(frame.Data, b) } if frame.Offset+frame.DataLen() > protocol.MaxByteCount { - return nil, errors.New("stream data overflows maximum offset") + return nil, 0, errors.New("stream data overflows maximum offset") } - return frame, nil + return frame, startLen - len(b) + int(dataLen), nil } -// Write writes a STREAM frame func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index 1694865d..8bcc1edc 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -17,70 +17,73 @@ var _ = Describe("STREAM frame", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, 0x8^0x4, protocol.Version1) + frame, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal([]byte("foobar"))) Expect(frame.Fin).To(BeFalse()) Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("respects the LEN when parsing the frame", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(4)...) // data length data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, 0x8^0x2, protocol.Version1) + frame, l, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal([]byte("foob"))) Expect(frame.Fin).To(BeFalse()) Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(Equal(2)) + Expect(l).To(Equal(len(data) - 2)) }) It("parses a frame with FIN bit", func() { data := encodeVarInt(9) // stream ID data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, 0x8^0x1, protocol.Version1) + frame, l, err := parseStreamFrame(data, 0x8^0x1, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) Expect(frame.Data).To(Equal([]byte("foobar"))) Expect(frame.Fin).To(BeTrue()) Expect(frame.Offset).To(BeZero()) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("allows empty frames", func() { data := encodeVarInt(0x1337) // stream ID data = append(data, encodeVarInt(0x12345)...) // offset - r := bytes.NewReader(data) - f, err := parseStreamFrame(r, 0x8^0x4, protocol.Version1) + f, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) Expect(f.Data).To(BeEmpty()) Expect(f.Fin).To(BeFalse()) + Expect(l).To(Equal(len(data))) }) It("rejects frames that overflow the maximum offset", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset data = append(data, []byte("foobar")...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, 0x8^0x4, protocol.Version1) + _, _, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) Expect(err).To(MatchError("stream data overflows maximum offset")) }) - It("rejects frames that claim to be longer than the packet size", func() { + It("rejects frames that claim to be longer than the packet buffer size", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) - r := bytes.NewReader(data) - _, err := parseStreamFrame(r, 0x8^0x2, protocol.Version1) + _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) + Expect(err).To(Equal(io.EOF)) + }) + + It("rejects frames that claim to be longer than the remaining size", func() { + data := encodeVarInt(0x12345) // stream ID + data = append(data, encodeVarInt(7)...) // data length + data = append(data, []byte("foobar")...) + _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) Expect(err).To(Equal(io.EOF)) }) @@ -90,10 +93,10 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) - _, err := parseStreamFrame(bytes.NewReader(data), typ, protocol.Version1) + _, _, err := parseStreamFrame(data, typ, protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err = parseStreamFrame(bytes.NewReader(data[:i]), typ, protocol.Version1) + _, _, err = parseStreamFrame(data[:i], typ, protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -103,30 +106,28 @@ var _ = Describe("STREAM frame", func() { It("uses the buffer for long STREAM frames", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, 0x8, protocol.Version1) + frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize)) Expect(frame.Fin).To(BeFalse()) Expect(frame.fromPool).To(BeTrue()) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) Expect(frame.PutBack).ToNot(Panic()) }) It("doesn't use the buffer for short STREAM frames", func() { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) - r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, 0x8, protocol.Version1) + frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) Expect(frame.DataLen()).To(BeEquivalentTo(protocol.MinStreamFrameBufferSize - 1)) Expect(frame.Fin).To(BeFalse()) Expect(frame.fromPool).To(BeFalse()) - Expect(r.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) Expect(frame.PutBack).ToNot(Panic()) }) }) diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index 39e7474e..c946fec3 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "fmt" "github.com/quic-go/quic-go/internal/protocol" @@ -14,7 +13,7 @@ type StreamsBlockedFrame struct { StreamLimit protocol.StreamNum } -func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) { +func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) { f := &StreamsBlockedFrame{} switch typ { case bidiStreamBlockedFrameType: @@ -22,15 +21,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) ( case uniStreamBlockedFrameType: f.Type = protocol.StreamTypeUni } - streamLimit, err := quicvarint.Read(r) + streamLimit, l, err := quicvarint.Parse(b) if err != nil { - return nil, err + return nil, 0, replaceUnexpectedEOF(err) } f.StreamLimit = protocol.StreamNum(streamLimit) if f.StreamLimit > protocol.MaxStreamCount { - return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) } - return f, nil + return f, l, nil } func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index b02af10c..1e3fcb07 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "fmt" "io" @@ -15,32 +14,30 @@ import ( var _ = Describe("STREAMS_BLOCKED frame", func() { Context("parsing", func() { It("accepts a frame for bidirectional streams", func() { - expected := encodeVarInt(0x1337) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, bidiStreamBlockedFrameType, protocol.Version1) + data := encodeVarInt(0x1337) + f, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) Expect(f.StreamLimit).To(BeEquivalentTo(0x1337)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("accepts a frame for unidirectional streams", func() { - expected := encodeVarInt(0x7331) - b := bytes.NewReader(expected) - f, err := parseStreamsBlockedFrame(b, uniStreamBlockedFrameType, protocol.Version1) + data := encodeVarInt(0x7331) + f, l, err := parseStreamsBlockedFrame(data, uniStreamBlockedFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.Type).To(Equal(protocol.StreamTypeUni)) Expect(f.StreamLimit).To(BeEquivalentTo(0x7331)) - Expect(b.Len()).To(BeZero()) + Expect(l).To(Equal(len(data))) }) It("errors on EOFs", func() { data := encodeVarInt(0x12345678) - b := bytes.NewReader(data) - _, err := parseStreamsBlockedFrame(b, bidiStreamBlockedFrameType, protocol.Version1) + _, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1) Expect(err).ToNot(HaveOccurred()) + Expect(l).To(Equal(len(data))) for i := range data { - _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), bidiStreamBlockedFrameType, protocol.Version1) + _, _, err := parseStreamsBlockedFrame(data[:i], bidiStreamBlockedFrameType, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -55,12 +52,13 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { } b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) - frame, err := parseStreamsBlockedFrame(r, typ, protocol.Version1) + b = b[l:] + frame, l, err := parseStreamsBlockedFrame(b, typ, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) + Expect(l).To(Equal(len(b))) }) It("errors when receiving a too large stream count", func() { @@ -70,10 +68,10 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { } b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - r := bytes.NewReader(b) - typ, err := quicvarint.Read(r) + typ, l, err := quicvarint.Parse(b) Expect(err).ToNot(HaveOccurred()) - _, err = parseStreamsBlockedFrame(r, typ, protocol.Version1) + b = b[l:] + _, _, err = parseStreamsBlockedFrame(b, typ, protocol.Version1) Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) }) }