From 741521c4d1f0800543207c8e2af4d5744aa8a424 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Jan 2019 14:17:31 +0700 Subject: [PATCH 1/5] refactor frame parsing into a separate struct --- internal/wire/frame_parser.go | 51 ++++++++++++++---------- internal/wire/frame_parser_test.go | 50 ++++++++++++----------- internal/wire/{frame.go => interface.go} | 5 +++ packet_packer_test.go | 3 +- session.go | 8 ++-- 5 files changed, 69 insertions(+), 48 deletions(-) rename internal/wire/{frame.go => interface.go} (67%) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 59a8459d6..caecc25d7 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -8,9 +8,18 @@ import ( "github.com/lucas-clemente/quic-go/internal/qerr" ) +type frameParser struct { + version protocol.VersionNumber +} + +// NewFrameParser creates a new frame parser. +func NewFrameParser(v protocol.VersionNumber) FrameParser { + return &frameParser{version: v} +} + // ParseNextFrame parses the next frame // It skips PADDING frames. -func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) { +func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { for r.Len() != 0 { typeByte, _ := r.ReadByte() if typeByte == 0x0 { // PADDING frame @@ -18,16 +27,16 @@ func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) { } r.UnreadByte() - return parseFrame(r, typeByte, v) + return p.parseFrame(r, typeByte) } return nil, nil } -func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) { +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) { var frame Frame var err error if typeByte&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, v) + frame, err = parseStreamFrame(r, p.version) if err != nil { return nil, qerr.Error(qerr.InvalidFrameData, err.Error()) } @@ -35,39 +44,39 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame } switch typeByte { case 0x1: - frame, err = parsePingFrame(r, v) + frame, err = parsePingFrame(r, p.version) case 0x2, 0x3: - frame, err = parseAckFrame(r, v) + frame, err = parseAckFrame(r, p.version) case 0x4: - frame, err = parseResetStreamFrame(r, v) + frame, err = parseResetStreamFrame(r, p.version) case 0x5: - frame, err = parseStopSendingFrame(r, v) + frame, err = parseStopSendingFrame(r, p.version) case 0x6: - frame, err = parseCryptoFrame(r, v) + frame, err = parseCryptoFrame(r, p.version) case 0x7: - frame, err = parseNewTokenFrame(r, v) + frame, err = parseNewTokenFrame(r, p.version) case 0x10: - frame, err = parseMaxDataFrame(r, v) + frame, err = parseMaxDataFrame(r, p.version) case 0x11: - frame, err = parseMaxStreamDataFrame(r, v) + frame, err = parseMaxStreamDataFrame(r, p.version) case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, v) + frame, err = parseMaxStreamsFrame(r, p.version) case 0x14: - frame, err = parseDataBlockedFrame(r, v) + frame, err = parseDataBlockedFrame(r, p.version) case 0x15: - frame, err = parseStreamDataBlockedFrame(r, v) + frame, err = parseStreamDataBlockedFrame(r, p.version) case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, v) + frame, err = parseStreamsBlockedFrame(r, p.version) case 0x18: - frame, err = parseNewConnectionIDFrame(r, v) + frame, err = parseNewConnectionIDFrame(r, p.version) case 0x19: - frame, err = parseRetireConnectionIDFrame(r, v) + frame, err = parseRetireConnectionIDFrame(r, p.version) case 0x1a: - frame, err = parsePathChallengeFrame(r, v) + frame, err = parsePathChallengeFrame(r, p.version) case 0x1b: - frame, err = parsePathResponseFrame(r, v) + frame, err = parsePathResponseFrame(r, p.version) case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, v) + frame, err = parseConnectionCloseFrame(r, p.version) default: err = fmt.Errorf("unknown type byte 0x%x", typeByte) } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 5245c4e87..adca648da 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -10,14 +10,18 @@ import ( ) var _ = Describe("Frame parsing", func() { - var buf *bytes.Buffer + var ( + buf *bytes.Buffer + parser FrameParser + ) BeforeEach(func() { buf = &bytes.Buffer{} + parser = NewFrameParser(versionIETFFrames) }) It("returns nil if there's nothing more to read", func() { - f, err := ParseNextFrame(bytes.NewReader(nil), protocol.VersionWhatever) + f, err := parser.ParseNext(bytes.NewReader(nil)) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -25,14 +29,14 @@ var _ = Describe("Frame parsing", func() { It("skips PADDING frames", func() { buf.Write([]byte{0}) // PADDING frame (&PingFrame{}).Write(buf, versionIETFFrames) - f, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) }) It("handles PADDING at the end", func() { r := bytes.NewReader([]byte{0, 0, 0}) - f, err := ParseNextFrame(r, versionIETFFrames) + f, err := parser.ParseNext(r) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) Expect(r.Len()).To(BeZero()) @@ -42,7 +46,7 @@ var _ = Describe("Frame parsing", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -57,7 +61,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -67,7 +71,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -79,7 +83,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -89,7 +93,7 @@ var _ = Describe("Frame parsing", func() { f := &NewTokenFrame{Token: []byte("foobar")} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -104,7 +108,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -117,7 +121,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -130,7 +134,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -143,7 +147,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -153,7 +157,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -165,7 +169,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -178,7 +182,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -191,7 +195,7 @@ var _ = Describe("Frame parsing", func() { } buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -200,7 +204,7 @@ var _ = Describe("Frame parsing", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -209,7 +213,7 @@ var _ = Describe("Frame parsing", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -220,7 +224,7 @@ var _ = Describe("Frame parsing", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -235,13 +239,13 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := ParseNextFrame(bytes.NewReader(buf.Bytes()), versionIETFFrames) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("errors on invalid type", func() { - _, err := ParseNextFrame(bytes.NewReader([]byte{0x42}), versionIETFFrames) + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42})) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x42")) }) @@ -252,7 +256,7 @@ var _ = Describe("Frame parsing", func() { } b := &bytes.Buffer{} f.Write(b, versionIETFFrames) - _, err := ParseNextFrame(bytes.NewReader(b.Bytes()[:b.Len()-2]), versionIETFFrames) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2])) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidFrameData)) }) diff --git a/internal/wire/frame.go b/internal/wire/interface.go similarity index 67% rename from internal/wire/frame.go rename to internal/wire/interface.go index 835905a41..d9837a6a2 100644 --- a/internal/wire/frame.go +++ b/internal/wire/interface.go @@ -11,3 +11,8 @@ type Frame interface { Write(b *bytes.Buffer, version protocol.VersionNumber) error Length(version protocol.VersionNumber) protocol.ByteCount } + +// A FrameParser parses QUIC frames, one by one. +type FrameParser interface { + ParseNext(r *bytes.Reader) (Frame, error) +} diff --git a/packet_packer_test.go b/packet_packer_test.go index ede87fd4b..55a7f1935 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -821,7 +821,8 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the stream frame - frame, err := wire.ParseNextFrame(r, packer.version) + frameParser := wire.NewFrameParser(packer.version) + frame, err := frameParser.ParseNext(r) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(r.Len()).To(BeZero()) diff --git a/session.go b/session.go index 6b5dc0fab..4122d6700 100644 --- a/session.go +++ b/session.go @@ -93,8 +93,9 @@ type session struct { windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController - unpacker unpacker - packer packer + unpacker unpacker + frameParser wire.FrameParser + packer packer cryptoStreamHandler cryptoStreamHandler @@ -292,6 +293,7 @@ var newClientSession = func( } func (s *session) preSetup() { + s.frameParser = wire.NewFrameParser(s.version) s.rttStats = &congestion.RTTStats{} s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( @@ -551,7 +553,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time r := bytes.NewReader(packet.data) var isRetransmittable bool for { - frame, err := wire.ParseNextFrame(r, s.version) + frame, err := s.frameParser.ParseNext(r) if err != nil { return err } From ee75f5e2f224a6125b42dacec672e16442abba5a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Jan 2019 14:32:57 +0700 Subject: [PATCH 2/5] implement ACK frame parsing using an ack delay exponent --- internal/protocol/params.go | 3 +++ internal/wire/ack_frame.go | 7 ++--- internal/wire/ack_frame_test.go | 48 ++++++++++++++++++++++----------- internal/wire/frame_parser.go | 2 +- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/internal/protocol/params.go b/internal/protocol/params.go index e6f9493fa..7077d825b 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -117,3 +117,6 @@ const MinPacingDelay time.Duration = 100 * time.Microsecond // DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections // if no other value is configured. const DefaultConnectionIDLength = 4 + +// AckDelayExponent is the ack delay exponent used when sending ACKs. +const AckDelayExponent = 3 diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index e2e8f471f..7909bffa8 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -10,9 +10,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// TODO: use the value sent in the transport parameters -const ackDelayExponent = 3 - var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") // An AckFrame is an ACK frame @@ -22,7 +19,7 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { +func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, version protocol.VersionNumber) (*AckFrame, error) { typeByte, err := r.ReadByte() if err != nil { return nil, err @@ -225,5 +222,5 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { } func encodeAckDelay(delay time.Duration) uint64 { - return uint64(delay.Nanoseconds() / (1000 * (1 << ackDelayExponent))) + return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) } diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index 65a06feeb..41d7382db 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -19,7 +19,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) @@ -34,7 +34,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) @@ -49,7 +49,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) @@ -64,7 +64,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block b := bytes.NewReader(data) - _, err := parseAckFrame(b, versionIETFFrames) + _, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).To(MatchError("invalid first ACK range")) }) @@ -77,7 +77,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) @@ -100,7 +100,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) @@ -113,6 +113,22 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(b.Len()).To(BeZero()) }) + It("uses the ack delay exponent", func() { + const delayTime = 1 << 10 * time.Millisecond + buf := &bytes.Buffer{} + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: delayTime, + } + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + for i := uint8(0); i < 8; i++ { + b := bytes.NewReader(buf.Bytes()) + frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) + } + }) + It("errors on EOF", func() { data := []byte{0x2} data = append(data, encodeVarInt(1000)...) // largest acked @@ -121,10 +137,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(100)...) // first ack block data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block - _, err := parseAckFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, versionIETFFrames) Expect(err).To(MatchError(io.EOF)) } }) @@ -140,7 +156,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE b := bytes.NewReader(data) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) @@ -159,10 +175,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x42)...) // ECT(0) data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - _, err := parseAckFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, versionIETFFrames) Expect(err).To(MatchError(io.EOF)) } }) @@ -196,7 +212,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -213,7 +229,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -233,7 +249,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -255,7 +271,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(err).ToNot(HaveOccurred()) Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -278,7 +294,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(b.Len()).To(BeZero()) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index caecc25d7..5dedf1f55 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -46,7 +46,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) case 0x1: frame, err = parsePingFrame(r, p.version) case 0x2, 0x3: - frame, err = parseAckFrame(r, p.version) + frame, err = parseAckFrame(r, protocol.AckDelayExponent, p.version) case 0x4: frame, err = parseResetStreamFrame(r, p.version) case 0x5: From 155ebd18a225c81b3332879d335727ae607a2f43 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Jan 2019 15:09:14 +0700 Subject: [PATCH 3/5] parse and send the ack_delay_exponent in the transport parameters --- client.go | 1 + .../handshake/transport_parameter_test.go | 21 ++++++++++++++++++- internal/handshake/transport_parameters.go | 20 ++++++++++++++++-- internal/protocol/protocol.go | 3 +++ server.go | 1 + 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index e4e819ed2..2219112c2 100644 --- a/client.go +++ b/client.go @@ -380,6 +380,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { IdleTimeout: c.config.IdleTimeout, MaxBidiStreams: uint64(c.config.MaxIncomingStreams), MaxUniStreams: uint64(c.config.MaxIncomingUniStreams), + AckDelayExponent: protocol.AckDelayExponent, DisableMigration: true, } diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 7d54d7b54..4247cf524 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -23,8 +23,9 @@ var _ = Describe("Transport Parameters", func() { MaxUniStreams: 7331, IdleTimeout: 42 * time.Second, OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + AckDelayExponent: 14, } - Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}")) + Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14}")) }) getRandomValue := func() uint64 { @@ -45,6 +46,7 @@ var _ = Describe("Transport Parameters", func() { DisableMigration: true, StatelessResetToken: bytes.Repeat([]byte{100}, 16), OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + AckDelayExponent: 13, } b := &bytes.Buffer{} params.marshal(b) @@ -61,6 +63,7 @@ var _ = Describe("Transport Parameters", func() { Expect(p.DisableMigration).To(Equal(params.DisableMigration)) Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) Expect(p.OriginalConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(p.AckDelayExponent).To(Equal(uint8(13))) }) It("errors when the stateless_reset_token has the wrong length", func() { @@ -89,6 +92,22 @@ var _ = Describe("Transport Parameters", func() { Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("wrong length for disable_migration: 6 (expected empty)")) }) + It("errors when the ack_delay_exponenent is too large", func() { + b := &bytes.Buffer{} + (&TransportParameters{AckDelayExponent: 21}).marshal(b) + p := &TransportParameters{} + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("invalid value for ack_delay_exponent: 21 (maximum 20)")) + }) + + It("doesn't send the ack_delay_exponent, if it has the default value", func() { + b := &bytes.Buffer{} + (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent}).marshal(b) + defaultLen := b.Len() + b.Reset() + (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent + 1}).marshal(b) + Expect(b.Len()).To(Equal(defaultLen + 2 /* parameter ID */ + 2 /* length field */ + 1 /* value */)) + }) + It("errors when the varint value has the wrong length", func() { b := &bytes.Buffer{} utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID)) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 500be0084..3b265bc81 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -25,6 +25,7 @@ const ( initialMaxStreamDataUniParameterID transportParameterID = 0x7 initialMaxStreamsBidiParameterID transportParameterID = 0x8 initialMaxStreamsUniParameterID transportParameterID = 0x9 + ackDelayExponentParameterID transportParameterID = 0xa disableMigrationParameterID transportParameterID = 0xc ) @@ -35,6 +36,8 @@ type TransportParameters struct { InitialMaxStreamDataUni protocol.ByteCount InitialMaxData protocol.ByteCount + AckDelayExponent uint8 + MaxPacketSize protocol.ByteCount MaxUniStreams uint64 @@ -65,7 +68,8 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective initialMaxStreamsBidiParameterID, initialMaxStreamsUniParameterID, idleTimeoutParameterID, - maxPacketSizeParameterID: + maxPacketSizeParameterID, + ackDelayExponentParameterID: if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { return err } @@ -147,6 +151,11 @@ func (p *TransportParameters) readNumericTransportParameter( return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) } p.MaxPacketSize = protocol.ByteCount(val) + case ackDelayExponentParameterID: + if val > 20 { + return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum 20)", val) + } + p.AckDelayExponent = uint8(val) default: return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) } @@ -186,6 +195,13 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID)) utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize)))) utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize)) + // ack_delay_exponent + // Only send it if is different from the default value. + if p.AckDelayExponent != protocol.DefaultAckDelayExponent { + utils.BigEndian.WriteUint16(b, uint16(ackDelayExponentParameterID)) + utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.AckDelayExponent)))) + utils.WriteVarInt(b, uint64(p.AckDelayExponent)) + } // disable_migration if p.DisableMigration { utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID)) @@ -206,5 +222,5 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { - return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) + return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent) } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index d190287d4..c50d6e78b 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -63,3 +63,6 @@ const MinStatelessResetSize = 1 /* first byte */ + 22 /* random bytes */ + 16 /* // MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. const MinConnectionIDLenInitial = 8 + +// DefaultAckDelayExponent is the default ack delay exponent +const DefaultAckDelayExponent = 3 diff --git a/server.go b/server.go index 9b106d2e8..9489fb13a 100644 --- a/server.go +++ b/server.go @@ -427,6 +427,7 @@ func (s *server) createNewSession( IdleTimeout: s.config.IdleTimeout, MaxBidiStreams: uint64(s.config.MaxIncomingStreams), MaxUniStreams: uint64(s.config.MaxIncomingUniStreams), + AckDelayExponent: protocol.AckDelayExponent, DisableMigration: true, // TODO(#855): generate a real token StatelessResetToken: bytes.Repeat([]byte{42}, 16), From cebb4342ecf80acf0e932b5a66cbfeecaab6190c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Jan 2019 15:37:38 +0700 Subject: [PATCH 4/5] use the ack_delay_exponent sent in the transport parameters --- internal/wire/frame_parser.go | 18 ++++++-- internal/wire/frame_parser_test.go | 71 +++++++++++++++++++++--------- internal/wire/interface.go | 3 +- packet_packer_test.go | 2 +- session.go | 3 +- 5 files changed, 68 insertions(+), 29 deletions(-) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 5dedf1f55..e0c946cbd 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -9,6 +9,8 @@ import ( ) type frameParser struct { + ackDelayExponent uint8 + version protocol.VersionNumber } @@ -19,7 +21,7 @@ func NewFrameParser(v protocol.VersionNumber) FrameParser { // ParseNextFrame parses the next frame // It skips PADDING frames. -func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { +func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { for r.Len() != 0 { typeByte, _ := r.ReadByte() if typeByte == 0x0 { // PADDING frame @@ -27,12 +29,12 @@ func (p *frameParser) ParseNext(r *bytes.Reader) (Frame, error) { } r.UnreadByte() - return p.parseFrame(r, typeByte) + return p.parseFrame(r, typeByte, encLevel) } return nil, nil } -func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) { +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { var frame Frame var err error if typeByte&0xf8 == 0x8 { @@ -46,7 +48,11 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) case 0x1: frame, err = parsePingFrame(r, p.version) case 0x2, 0x3: - frame, err = parseAckFrame(r, protocol.AckDelayExponent, p.version) + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) case 0x4: frame, err = parseResetStreamFrame(r, p.version) case 0x5: @@ -85,3 +91,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte) (Frame, error) } return frame, nil } + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index adca648da..44a8987a1 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "time" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -21,7 +22,7 @@ var _ = Describe("Frame parsing", func() { }) It("returns nil if there's nothing more to read", func() { - f, err := parser.ParseNext(bytes.NewReader(nil)) + f, err := parser.ParseNext(bytes.NewReader(nil), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -29,14 +30,14 @@ var _ = Describe("Frame parsing", func() { It("skips PADDING frames", func() { buf.Write([]byte{0}) // PADDING frame (&PingFrame{}).Write(buf, versionIETFFrames) - f, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) }) It("handles PADDING at the end", func() { r := bytes.NewReader([]byte{0, 0, 0}) - f, err := parser.ParseNext(r) + f, err := parser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeNil()) Expect(r.Len()).To(BeZero()) @@ -46,13 +47,39 @@ var _ = Describe("Frame parsing", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) Expect(frame.(*AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x13))) }) + It("uses the custom ack delay exponent for 1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + // The ACK frame is always written using the protocol.AckDelayExponent. + // That's why we expect a different value when parsing. + Expect(frame.(*AckFrame).DelayTime).To(Equal(4 * time.Second)) + }) + + It("uses the default ack delay exponent for non-1RTT packets", func() { + parser.SetAckDelayExponent(protocol.AckDelayExponent + 2) + f := &AckFrame{ + AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, + DelayTime: time.Second, + } + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) + }) + It("unpacks RESET_STREAM frames", func() { f := &ResetStreamFrame{ StreamID: 0xdeadbeef, @@ -61,7 +88,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -71,7 +98,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -83,7 +110,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -93,7 +120,7 @@ var _ = Describe("Frame parsing", func() { f := &NewTokenFrame{Token: []byte("foobar")} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -108,7 +135,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -121,7 +148,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -134,7 +161,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -147,7 +174,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -157,7 +184,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -169,7 +196,7 @@ var _ = Describe("Frame parsing", func() { } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -182,7 +209,7 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -195,7 +222,7 @@ var _ = Describe("Frame parsing", func() { } buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -204,7 +231,7 @@ var _ = Describe("Frame parsing", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -213,7 +240,7 @@ var _ = Describe("Frame parsing", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -224,7 +251,7 @@ var _ = Describe("Frame parsing", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -239,13 +266,13 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes())) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("errors on invalid type", func() { - _, err := parser.ParseNext(bytes.NewReader([]byte{0x42})) + _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x42")) }) @@ -256,7 +283,7 @@ var _ = Describe("Frame parsing", func() { } b := &bytes.Buffer{} f.Write(b, versionIETFFrames) - _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2])) + _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidFrameData)) }) diff --git a/internal/wire/interface.go b/internal/wire/interface.go index d9837a6a2..99fdc80fb 100644 --- a/internal/wire/interface.go +++ b/internal/wire/interface.go @@ -14,5 +14,6 @@ type Frame interface { // A FrameParser parses QUIC frames, one by one. type FrameParser interface { - ParseNext(r *bytes.Reader) (Frame, error) + ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + SetAckDelayExponent(uint8) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 55a7f1935..dda434bd2 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -822,7 +822,7 @@ var _ = Describe("Packet packer", func() { Expect(firstPayloadByte).To(Equal(byte(0))) // ... followed by the stream frame frameParser := wire.NewFrameParser(packer.version) - frame, err := frameParser.ParseNext(r) + frame, err := frameParser.ParseNext(r, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(r.Len()).To(BeZero()) diff --git a/session.go b/session.go index 4122d6700..ba79ef65d 100644 --- a/session.go +++ b/session.go @@ -553,7 +553,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time r := bytes.NewReader(packet.data) var isRetransmittable bool for { - frame, err := s.frameParser.ParseNext(r) + frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) if err != nil { return err } @@ -814,6 +814,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete s.peerParams = params s.streamsMap.UpdateLimits(params) s.packer.HandleTransportParameters(params) + s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) // the crypto stream is the only open stream at this moment // so we don't need to update stream flow control windows From 6834c37462ac6296cc80307b2707362cc83558c3 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 28 Jan 2019 16:37:00 +0900 Subject: [PATCH 5/5] move the maximum ack delay exponennt to the protocol constants --- internal/handshake/transport_parameters.go | 4 ++-- internal/protocol/protocol.go | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 3b265bc81..5bc0695dc 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -152,8 +152,8 @@ func (p *TransportParameters) readNumericTransportParameter( } p.MaxPacketSize = protocol.ByteCount(val) case ackDelayExponentParameterID: - if val > 20 { - return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum 20)", val) + if val > protocol.MaxAckDelayExponent { + return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) } p.AckDelayExponent = uint8(val) default: diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index c50d6e78b..c18a2b65c 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -66,3 +66,6 @@ const MinConnectionIDLenInitial = 8 // DefaultAckDelayExponent is the default ack delay exponent const DefaultAckDelayExponent = 3 + +// MaxAckDelayExponent is the maximum ack delay exponent +const MaxAckDelayExponent = 20