diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 59a8459d..caecc25d 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -8,9 +8,18 @@ import ( "github.com/lucas-clemente/quic-go/internal/qerr" ) +type frameParser struct { + version protocol.VersionNumber +} + +// NewFrameParser creates a new frame parser. +func NewFrameParser(v protocol.VersionNumber) FrameParser { + return &frameParser{version: v} +} + // ParseNextFrame parses the next frame // It skips PADDING frames. -func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) { +func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { for r.Len() != 0 { typeByte, _ := r.ReadByte() if typeByte == 0x0 { // PADDING frame @@ -18,16 +27,16 @@ func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) { } r.UnreadByte() - return parseFrame(r, typeByte, v) + return p.parseFrame(r, typeByte) } return nil, nil } -func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) { +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) { var frame Frame var err error if typeByte&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, v) + frame, err = parseStreamFrame(r, p.version) if err != nil { return nil, qerr.Error(qerr.InvalidFrameData, err.Error()) } @@ -35,39 +44,39 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame } switch typeByte { case 0x1: - frame, err = parsePingFrame(r, v) + frame, err = parsePingFrame(r, p.version) case 0x2, 0x3: - frame, err = parseAckFrame(r, v) + frame, err = parseAckFrame(r, p.version) case 0x4: - frame, err = parseResetStreamFrame(r, v) + frame, err = parseResetStreamFrame(r, p.version) case 0x5: - frame, err = parseStopSendingFrame(r, v) + frame, err = parseStopSendingFrame(r, p.version) case 0x6: - frame, err = parseCryptoFrame(r, v) + frame, err = parseCryptoFrame(r, p.version) case 0x7: - frame, err = parseNewTokenFrame(r, v) + frame, err = parseNewTokenFrame(r, p.version) case 0x10: - frame, err = parseMaxDataFrame(r, v) + frame, err = parseMaxDataFrame(r, p.version) case 0x11: - frame, err = parseMaxStreamDataFrame(r, v) + frame, err = parseMaxStreamDataFrame(r, p.version) case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, v) + frame, err = parseMaxStreamsFrame(r, p.version) case 0x14: - frame, err = parseDataBlockedFrame(r, v) + frame, err = parseDataBlockedFrame(r, p.version) case 0x15: - frame, err = parseStreamDataBlockedFrame(r, v) + frame, err = parseStreamDataBlockedFrame(r, p.version) case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, v) + frame, err = parseStreamsBlockedFrame(r, p.version) case 0x18: - frame, err = parseNewConnectionIDFrame(r, v) + frame, err = parseNewConnectionIDFrame(r, p.version) case 0x19: - frame, err = parseRetireConnectionIDFrame(r, v) + frame, err = parseRetireConnectionIDFrame(r, p.version) case 0x1a: - frame, err = parsePathChallengeFrame(r, v) + frame, err = parsePathChallengeFrame(r, p.version) case 0x1b: - frame, err = parsePathResponseFrame(r, v) + frame, err = parsePathResponseFrame(r, p.version) case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, v) + frame, err = parseConnectionCloseFrame(r, p.version) default: err = fmt.Errorf("unknown type byte 0x%x", typeByte) } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 5245c4e8..adca648d 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -10,14 +10,18 @@ import ( ) var _ = Describe("Frame parsing", func() { - var buf *bytes.Buffer + var ( + buf *bytes.Buffer + parser FrameParser + ) BeforeEach(func() { buf = &bytes.Buffer{} + parser = NewFrameParser(versionIETFFrames) }) It("returns nil if there's nothing more to read", func() { - f, err := ParseNextFrame(bytes.NewReader(nil), protocol.VersionWhatever) + f, err := parser.ParseNext(bytes.NewReader(nil)) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -25,14 +29,14 @@ var _ = Describe("Frame parsing", func() { It("skips PADDING frames", func() { buf.Write([]byte{0}) // PADDING frame (&PingFrame{}).Write(buf, versionIETFFrames) - f, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) }) It("handles PADDING at the end", func() { r := bytes.NewReader([]byte{0, 0, 0}) - f, err := ParseNextFrame(r, versionIETFFrames) + f, err := parser.ParseNext(r) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) Expect(r.Len()).To(BeZero()) @@ -42,7 +46,7 @@ var _ = Describe("Frame parsing", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -57,7 +61,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -67,7 +71,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -79,7 +83,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -89,7 +93,7 @@ var _ = Describe("Frame parsing", func() { f := &NewTokenFrame{Token: []byte("foobar")} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -104,7 +108,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -117,7 +121,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -130,7 +134,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -143,7 +147,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -153,7 +157,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -165,7 +169,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -178,7 +182,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -191,7 +195,7 @@ var _ = Describe("Frame parsing", func() { } buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -200,7 +204,7 @@ var _ = Describe("Frame parsing", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -209,7 +213,7 @@ var _ = Describe("Frame parsing", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -220,7 +224,7 @@ var _ = Describe("Frame parsing", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -235,13 +239,13 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("errors on invalid type", func() { - _, err := ParseNextFrame(bytes.NewReader([]byte{0x42}), versionIETFFrames) + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42})) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x42")) }) @@ -252,7 +256,7 @@ var _ = Describe("Frame parsing", func() { } b := &bytes.Buffer{} f.Write(b, versionIETFFrames) - _, err := ParseNextFrame(bytes.NewReader(b.Bytes()[:b.Len()-2]), versionIETFFrames) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2])) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidFrameData)) }) diff --git a/internal/wire/frame.go b/internal/wire/interface.go similarity index 67% rename from internal/wire/frame.go rename to internal/wire/interface.go index 835905a4..d9837a6a 100644 --- a/internal/wire/frame.go +++ b/internal/wire/interface.go @@ -11,3 +11,8 @@ type Frame interface { Write(b *bytes.Buffer, version protocol.VersionNumber) error Length(version protocol.VersionNumber) protocol.ByteCount } + +// A FrameParser parses QUIC frames, one by one. +type FrameParser interface { + ParseNext(r *bytes.Reader) (Frame, error) +} diff --git a/packet_packer_test.go b/packet_packer_test.go index ede87fd4..55a7f193 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -821,7 +821,8 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the stream frame - frame, err := wire.ParseNextFrame(r, packer.version) + frameParser := wire.NewFrameParser(packer.version) + frame, err := frameParser.ParseNext(r) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(r.Len()).To(BeZero()) diff --git a/session.go b/session.go index 6b5dc0fa..4122d670 100644 --- a/session.go +++ b/session.go @@ -93,8 +93,9 @@ type session struct { windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController - unpacker unpacker - packer packer + unpacker unpacker + frameParser wire.FrameParser + packer packer cryptoStreamHandler cryptoStreamHandler @@ -292,6 +293,7 @@ var newClientSession = func( } func (s *session) preSetup() { + s.frameParser = wire.NewFrameParser(s.version) s.rttStats = &congestion.RTTStats{} s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( @@ -551,7 +553,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time r := bytes.NewReader(packet.data) var isRetransmittable bool for { - frame, err := wire.ParseNextFrame(r, s.version) + frame, err := s.frameParser.ParseNext(r) if err != nil { return err }