diff --git a/internal/wire/blocked_frame_legacy.go b/internal/wire/blocked_frame_legacy.go index d60ca4c4..19274cf4 100644 --- a/internal/wire/blocked_frame_legacy.go +++ b/internal/wire/blocked_frame_legacy.go @@ -16,8 +16,7 @@ type blockedFrameLegacy struct { // * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream // * a BLOCKED frame, if the BLOCKED applies to the connection func ParseBlockedFrameLegacy(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { - // read the TypeByte - if _, err := r.ReadByte(); err != nil { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } streamID, err := utils.GetByteOrder(version).ReadUint32(r) diff --git a/internal/wire/blocked_frame_legacy_test.go b/internal/wire/blocked_frame_legacy_test.go index 6d959480..6084ea69 100644 --- a/internal/wire/blocked_frame_legacy_test.go +++ b/internal/wire/blocked_frame_legacy_test.go @@ -10,60 +10,56 @@ import ( var _ = Describe("legacy BLOCKED Frame", func() { Context("when parsing", func() { - Context("in big endian", func() { - It("accepts sample frame for a stream", func() { - b := bytes.NewReader([]byte{0x5, 0xde, 0xad, 0xbe, 0xef}) - f, err := ParseBlockedFrameLegacy(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&StreamBlockedFrame{})) - frame := f.(*StreamBlockedFrame) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - }) - - It("accepts sample frame for the connection", func() { - b := bytes.NewReader([]byte{0x5, 0x0, 0x0, 0x0, 0x0}) - f, err := ParseBlockedFrameLegacy(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&BlockedFrame{})) - }) + It("accepts sample frame for a stream", func() { + b := bytes.NewReader([]byte{0x5, 0xde, 0xad, 0xbe, 0xef}) + f, err := ParseBlockedFrameLegacy(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&StreamBlockedFrame{})) + frame := f.(*StreamBlockedFrame) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) }) - It("errors on EOFs", func() { - data := []byte{0x5, 0xef, 0xbe, 0xad, 0xde} - _, err := ParseBlockedFrameLegacy(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := ParseBlockedFrameLegacy(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } + It("accepts sample frame for the connection", func() { + b := bytes.NewReader([]byte{0x5, 0x0, 0x0, 0x0, 0x0}) + f, err := ParseBlockedFrameLegacy(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&BlockedFrame{})) }) }) + It("errors on EOFs", func() { + data := []byte{0x5, 0xef, 0xbe, 0xad, 0xde} + _, err := ParseBlockedFrameLegacy(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseBlockedFrameLegacy(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + Context("when writing", func() { - Context("in big endian", func() { - It("writes a BLOCKED frame for a stream", func() { - b := &bytes.Buffer{} - frame := StreamBlockedFrame{StreamID: 0x1337} - frame.Write(b, versionBigEndian) - Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x13, 0x37})) - }) + It("writes a BLOCKED frame for a stream", func() { + b := &bytes.Buffer{} + frame := StreamBlockedFrame{StreamID: 0x1337} + frame.Write(b, versionBigEndian) + Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x13, 0x37})) + }) - It("has the correct min length for a BLOCKED frame for a stream", func() { - frame := StreamBlockedFrame{StreamID: 3} - Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(5))) - }) + It("has the correct min length for a BLOCKED frame for a stream", func() { + frame := StreamBlockedFrame{StreamID: 3} + Expect(frame.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(5))) + }) - It("writes a BLOCKED frame for the connection", func() { - b := &bytes.Buffer{} - frame := BlockedFrame{} - frame.Write(b, versionBigEndian) - Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x0, 0x0})) - }) + It("writes a BLOCKED frame for the connection", func() { + b := &bytes.Buffer{} + frame := BlockedFrame{} + frame.Write(b, versionBigEndian) + Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x0, 0x0})) + }) - It("has the correct min length for a BLOCKED frame for the connection", func() { - frame := BlockedFrame{} - Expect(frame.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(5))) - }) + It("has the correct min length for a BLOCKED frame for the connection", func() { + frame := BlockedFrame{} + Expect(frame.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(5))) }) }) }) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go index 432c6a86..05923c2f 100644 --- a/internal/wire/connection_close_frame.go +++ b/internal/wire/connection_close_frame.go @@ -19,23 +19,33 @@ type ConnectionCloseFrame struct { // ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { - frame := &ConnectionCloseFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } - errorCode, err := utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - frame.ErrorCode = qerr.ErrorCode(errorCode) - - reasonPhraseLen, err := utils.GetByteOrder(version).ReadUint16(r) - if err != nil { - return nil, err + var errorCode qerr.ErrorCode + var reasonPhraseLen uint64 + if version.UsesIETFFrameFormat() { + ec, err := utils.GetByteOrder(version).ReadUint16(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + reasonPhraseLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + } else { + ec, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + length, err := utils.GetByteOrder(version).ReadUint16(r) + if err != nil { + return nil, err + } + reasonPhraseLen = uint64(length) } // shortcut to prevent the unneccessary allocation of dataLen bytes @@ -50,27 +60,36 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) // this should never happen, since we already checked the reasonPhraseLen earlier return nil, err } - frame.ReasonPhrase = string(reasonPhrase) - return frame, nil + return &ConnectionCloseFrame{ + ErrorCode: qerr.ErrorCode(errorCode), + ReasonPhrase: string(reasonPhrase), + }, nil } // MinLength of a written frame func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + if version.UsesIETFFrameFormat() { + return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)), nil + } return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil } // Write writes an CONNECTION_CLOSE frame. func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x02) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) if len(f.ReasonPhrase) > math.MaxUint16 { return errors.New("ConnectionFrame: ReasonPhrase too long") } - reasonPhraseLen := uint16(len(f.ReasonPhrase)) - utils.GetByteOrder(version).WriteUint16(b, reasonPhraseLen) + if version.UsesIETFFrameFormat() { + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(len(f.ReasonPhrase))) + } else { + utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) + utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.ReasonPhrase))) + } b.WriteString(f.ReasonPhrase) return nil diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go index 150d4c45..b0bffa86 100644 --- a/internal/wire/connection_close_frame_test.go +++ b/internal/wire/connection_close_frame_test.go @@ -3,7 +3,6 @@ package wire import ( "bytes" "io" - "strings" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -11,8 +10,56 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("ConnectionCloseFrame", func() { +var _ = Describe("CONNECTION_CLOSE Frame", func() { Context("when parsing", func() { + Context("in varint encoding", func() { + It("accepts sample frame", func() { + data := []byte{0x2, 0x0, 0x19} + data = append(data, encodeVarInt(0x1b)...) // reason phrase length + data = append(data, []byte{ + 'N', 'o', ' ', 'r', 'e', 'c', 'e', 'n', 't', ' ', 'n', 'e', 't', 'w', 'o', 'r', 'k', ' ', 'a', 'c', 't', 'i', 'v', 'i', 't', 'y', '.', + }...) + b := bytes.NewReader(data) + frame, err := ParseConnectionCloseFrame(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.ErrorCode).To(Equal(qerr.ErrorCode(0x19))) + Expect(frame.ReasonPhrase).To(Equal("No recent network activity.")) + Expect(b.Len()).To(BeZero()) + }) + + It("rejects long reason phrases", func() { + data := []byte{0x2, 0xca, 0xfe} + data = append(data, encodeVarInt(0xffff)...) // reason phrase length + b := bytes.NewReader(data) + _, err := ParseConnectionCloseFrame(b, versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + }) + + It("errors on EOFs", func() { + data := []byte{0x2, 0x0, 0x19} + data = append(data, encodeVarInt(0x1b)...) // reason phrase length + data = append(data, []byte{ + 'N', 'o', ' ', 'r', 'e', 'c', 'e', 'n', 't', ' ', 'n', 'e', 't', 'w', 'o', 'r', 'k', ' ', 'a', 'c', 't', 'i', 'v', 'i', 't', 'y', '.', + }...) + _, err := ParseConnectionCloseFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseConnectionCloseFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + Expect(err).To(HaveOccurred()) + } + }) + + It("parses a frame without a reason phrase", func() { + data := []byte{0x2, 0xca, 0xfe} + data = append(data, encodeVarInt(0)...) + b := bytes.NewReader(data) + frame, err := ParseConnectionCloseFrame(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.ReasonPhrase).To(BeEmpty()) + Expect(b.Len()).To(BeZero()) + }) + }) + Context("in big endian", func() { It("accepts sample frame", func() { b := bytes.NewReader([]byte{0x2, @@ -49,21 +96,60 @@ var _ = Describe("ConnectionCloseFrame", func() { Expect(err).To(HaveOccurred()) } }) - }) - It("parses a frame without a reason phrase", func() { - b := bytes.NewReader([]byte{0x2, - 0xad, 0xfb, 0xca, 0xde, // error code - 0x0, 0x0, // reason phrase length + It("parses a frame without a reason phrase", func() { + b := bytes.NewReader([]byte{0x2, + 0xad, 0xfb, 0xca, 0xde, // error code + 0x0, 0x0, // reason phrase length + }) + frame, err := ParseConnectionCloseFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.ReasonPhrase).To(BeEmpty()) + Expect(b.Len()).To(BeZero()) }) - frame, err := ParseConnectionCloseFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.ReasonPhrase).To(BeEmpty()) - Expect(b.Len()).To(BeZero()) }) }) Context("when writing", func() { + Context("in varint encoding", func() { + It("writes a frame without a ReasonPhrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xbeef, + } + err := frame.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x2, 0xbe, 0xef} + expected = append(expected, encodeVarInt(0)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with a ReasonPhrase", func() { + b := &bytes.Buffer{} + frame := &ConnectionCloseFrame{ + ErrorCode: 0xdead, + ReasonPhrase: "foobar", + } + err := frame.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x2, 0xde, 0xad} + expected = append(expected, encodeVarInt(6)...) + expected = append(expected, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has proper min length", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + ErrorCode: 0xcafe, + ReasonPhrase: "foobar", + } + err := f.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(f.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(b.Len()))) + }) + }) + Context("in big endian", func() { It("writes a frame without a ReasonPhrase", func() { b := &bytes.Buffer{} @@ -94,44 +180,17 @@ var _ = Describe("ConnectionCloseFrame", func() { 'f', 'o', 'o', 'b', 'a', 'r', })) }) - }) - It("rejects ReasonPhrases that are too long", func() { - b := &bytes.Buffer{} - reasonPhrase := strings.Repeat("a", 0xffff+0x11) - frame := &ConnectionCloseFrame{ - ErrorCode: 0xdeadbeef, - ReasonPhrase: reasonPhrase, - } - err := frame.Write(b, protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) + It("has proper min length", func() { + b := &bytes.Buffer{} + f := &ConnectionCloseFrame{ + ErrorCode: 0xcafe, + ReasonPhrase: "foobar", + } + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(b.Len()))) + }) }) - - It("has proper min length", func() { - b := &bytes.Buffer{} - f := &ConnectionCloseFrame{ - ErrorCode: 0xdeadbeef, - ReasonPhrase: "foobar", - } - err := f.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(b.Len()))) - }) - }) - - It("is self-consistent", func() { - buf := &bytes.Buffer{} - frame := &ConnectionCloseFrame{ - ErrorCode: 0xdeadbeef, - ReasonPhrase: "Lorem ipsum dolor sit amet.", - } - err := frame.Write(buf, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(buf.Bytes()) - readframe, err := ParseConnectionCloseFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(readframe.ErrorCode).To(Equal(frame.ErrorCode)) - Expect(readframe.ReasonPhrase).To(Equal(frame.ReasonPhrase)) - Expect(b.Len()).To(BeZero()) }) }) diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go index 45f82b20..19585bc6 100644 --- a/internal/wire/max_data_frame.go +++ b/internal/wire/max_data_frame.go @@ -20,7 +20,7 @@ func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDat } frame := &MaxDataFrame{} - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + byteOffset, err := utils.ReadVarInt(r) if err != nil { return nil, err } @@ -38,7 +38,7 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er }).Write(b, version) } b.WriteByte(0x4) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) return nil } @@ -47,5 +47,5 @@ func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteC if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer return 1 + 4 + 8, nil } - return 1 + 8, nil + return 1 + utils.VarIntLen(uint64(f.ByteOffset)), nil } diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go index c4507660..1a098d8d 100644 --- a/internal/wire/max_data_frame_test.go +++ b/internal/wire/max_data_frame_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -11,19 +12,18 @@ import ( var _ = Describe("MAX_DATA frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x4, - 0xde, 0xca, 0xfb, 0xad, 0x11, 0x22, 0x33, 0x44, // byte offset - }) + data := []byte{0x4} + data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset + b := bytes.NewReader(data) frame, err := ParseMaxDataFrame(b, versionBigEndian) Expect(err).ToNot(HaveOccurred()) - Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0xdecafbad11223344))) + Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0xdecafbad123456))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { - data := []byte{0x4, - 0x44, 0x33, 0x22, 0x11, 0xad, 0xfb, 0xca, 0xde, // byte offset - } + data := []byte{0x4} + data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset _, err := ParseMaxDataFrame(bytes.NewReader(data), versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { @@ -38,19 +38,19 @@ var _ = Describe("MAX_DATA frame", func() { f := &MaxDataFrame{ ByteOffset: 0xdeadbeef, } - Expect(f.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1 + 8))) + Expect(f.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0xdeadbeef))) }) It("writes a MAX_DATA frame", func() { b := &bytes.Buffer{} f := &MaxDataFrame{ - ByteOffset: 0xdeadbeefcafe1337, + ByteOffset: 0xdeadbeefcafe, } err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x4, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // byte offset - })) + expected := []byte{0x4} + expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) + Expect(b.Bytes()).To(Equal(expected)) }) }) }) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go index 2f959565..810dc92a 100644 --- a/internal/wire/max_stream_data_frame.go +++ b/internal/wire/max_stream_data_frame.go @@ -22,13 +22,13 @@ func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (* return nil, err } - sid, err := utils.GetByteOrder(version).ReadUint32(r) + sid, err := utils.ReadVarInt(r) if err != nil { return nil, err } frame.StreamID = protocol.StreamID(sid) - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + byteOffset, err := utils.ReadVarInt(r) if err != nil { return nil, err } @@ -45,12 +45,16 @@ func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumb }).Write(b, version) } b.WriteByte(0x5) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) return nil } // MinLength of a written frame func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8, nil + // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length + if !version.UsesIETFFrameFormat() { + return 1 + 4 + 8, nil + } + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)), nil } diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go index 7940ab30..d9453908 100644 --- a/internal/wire/max_stream_data_frame_test.go +++ b/internal/wire/max_stream_data_frame_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -11,22 +12,21 @@ import ( var _ = Describe("MAX_STREAM_DATA frame", func() { Context("parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x5, - 0xde, 0xad, 0xbe, 0xef, // stream id - 0xde, 0xca, 0xfb, 0xad, 0x11, 0x22, 0x33, 0x44, // byte offset - }) + data := []byte{0x5} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset + b := bytes.NewReader(data) frame, err := ParseMaxStreamDataFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0xdecafbad11223344))) + Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x12345678))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { - data := []byte{0x5, - 0xef, 0xbe, 0xad, 0xde, // stream id - 0x44, 0x33, 0x22, 0x11, 0xad, 0xfb, 0xca, 0xde, // byte offset - } + data := []byte{0x5} + data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID + data = append(data, encodeVarInt(0x12345678)...) // Offset _, err := ParseMaxStreamDataFrame(bytes.NewReader(data), versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { @@ -42,21 +42,21 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { StreamID: 0x1337, ByteOffset: 0xdeadbeef, } - Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(13))) + Expect(f.MinLength(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)))) }) It("writes a sample frame", func() { b := &bytes.Buffer{} f := &MaxStreamDataFrame{ StreamID: 0xdecafbad, - ByteOffset: 0xdeadbeefcafe1337, + ByteOffset: 0xdeadbeefcafe42, } + expected := []byte{0x5} + expected = append(expected, encodeVarInt(0xdecafbad)...) + expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x5, - 0xde, 0xca, 0xfb, 0xad, // stream id - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // byte offset - })) + Expect(b.Bytes()).To(Equal(expected)) }) }) }) diff --git a/internal/wire/rst_stream_frame.go b/internal/wire/rst_stream_frame.go index 04086f83..7bfe1741 100644 --- a/internal/wire/rst_stream_frame.go +++ b/internal/wire/rst_stream_frame.go @@ -14,44 +14,75 @@ type RstStreamFrame struct { ByteOffset protocol.ByteCount } +// ParseRstStreamFrame parses a RST_STREAM frame +func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var errorCode uint32 + var byteOffset protocol.ByteCount + if version.UsesIETFFrameFormat() { + sid, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + ec, err := utils.GetByteOrder(version).ReadUint16(r) + if err != nil { + return nil, err + } + errorCode = uint32(ec) + bo, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + } else { + sid, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + bo, err := utils.GetByteOrder(version).ReadUint64(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + ec, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = uint32(ec) + } + + return &RstStreamFrame{ + StreamID: streamID, + ErrorCode: errorCode, + ByteOffset: byteOffset, + }, nil +} + //Write writes a RST_STREAM frame func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x01) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) - utils.GetByteOrder(version).WriteUint32(b, f.ErrorCode) + if version.UsesIETFFrameFormat() { + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) + } else { + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.GetByteOrder(version).WriteUint32(b, f.ErrorCode) + } return nil } // MinLength of a written frame func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + if version.UsesIETFFrameFormat() { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)), nil + } return 1 + 4 + 8 + 4, nil } - -// ParseRstStreamFrame parses a RST_STREAM frame -func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { - frame := &RstStreamFrame{} - - // read the TypeByte - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) - if err != nil { - return nil, err - } - frame.ByteOffset = protocol.ByteCount(byteOffset) - - frame.ErrorCode, err = utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - return frame, nil -} diff --git a/internal/wire/rst_stream_frame_test.go b/internal/wire/rst_stream_frame_test.go index 5a615705..380c055f 100644 --- a/internal/wire/rst_stream_frame_test.go +++ b/internal/wire/rst_stream_frame_test.go @@ -4,12 +4,41 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("RstStreamFrame", func() { +var _ = Describe("RST_STREAM frame", func() { Context("when parsing", func() { + Context("in varint encoding", func() { + It("accepts sample frame", func() { + data := []byte{0x1} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, []byte{0x13, 0x37}...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + b := bytes.NewReader(data) + frame, err := ParseRstStreamFrame(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x987654321))) + Expect(frame.ErrorCode).To(Equal(uint32(0x1337))) + }) + + It("errors on EOFs", func() { + data := []byte{0x1} + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, []byte{0x13, 0x37}...) // error code + data = append(data, encodeVarInt(0x987654321)...) // byte offset + _, err := ParseRstStreamFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseRstStreamFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + Expect(err).To(HaveOccurred()) + } + }) + }) + Context("in big endian", func() { It("accepts sample frame", func() { b := bytes.NewReader([]byte{0x1, @@ -23,26 +52,54 @@ var _ = Describe("RstStreamFrame", func() { Expect(frame.ByteOffset).To(Equal(protocol.ByteCount(0x8877665544332211))) Expect(frame.ErrorCode).To(Equal(uint32(0x34123713))) }) - }) - It("errors on EOFs", func() { - data := []byte{0x1, - 0xef, 0xbe, 0xad, 0xde, 0x44, // stream id - 0x33, 0x22, 0x11, 0xad, 0xfb, 0xca, 0xde, 0x34, // byte offset - 0x12, 0x37, 0x13, // error code - } - _, err := ParseRstStreamFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := ParseRstStreamFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } + It("errors on EOFs", func() { + data := []byte{0x1, + 0xef, 0xbe, 0xad, 0xde, 0x44, // stream id + 0x33, 0x22, 0x11, 0xad, 0xfb, 0xca, 0xde, 0x34, // byte offset + 0x12, 0x37, 0x13, // error code + } + _, err := ParseRstStreamFrame(bytes.NewReader(data), versionBigEndian) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseRstStreamFrame(bytes.NewReader(data[0:i]), versionBigEndian) + Expect(err).To(HaveOccurred()) + } + }) }) }) Context("when writing", func() { + Context("in varint encoding", func() { + It("writes a sample frame", func() { + frame := RstStreamFrame{ + StreamID: 0x1337, + ByteOffset: 0x11223344decafbad, + ErrorCode: 0xcafe, + } + b := &bytes.Buffer{} + err := frame.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x1} + expected = append(expected, encodeVarInt(0x1337)...) + expected = append(expected, []byte{0xca, 0xfe}...) + expected = append(expected, encodeVarInt(0x11223344decafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + rst := RstStreamFrame{ + StreamID: 0x1337, + ByteOffset: 0x1234567, + ErrorCode: 0xde, + } + expectedLen := 1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0x1234567) + 2 + Expect(rst.MinLength(versionIETFFrames)).To(Equal(expectedLen)) + }) + }) + Context("in big endian", func() { - It("writes a sample RstStreamFrame", func() { + It("writes a sample frame", func() { frame := RstStreamFrame{ StreamID: 0x1337, ByteOffset: 0x11223344decafbad, @@ -57,15 +114,15 @@ var _ = Describe("RstStreamFrame", func() { 0xde, 0xad, 0xbe, 0xef, // error code })) }) - }) - It("has the correct min length", func() { - rst := RstStreamFrame{ - StreamID: 0x1337, - ByteOffset: 0x1000, - ErrorCode: 0xde, - } - Expect(rst.MinLength(0)).To(Equal(protocol.ByteCount(17))) + It("has the correct min length", func() { + rst := RstStreamFrame{ + StreamID: 0x1337, + ByteOffset: 0x1000, + ErrorCode: 0xde, + } + Expect(rst.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(17))) + }) }) }) }) diff --git a/internal/wire/stream_blocked_frame.go b/internal/wire/stream_blocked_frame.go index b6415190..510de50d 100644 --- a/internal/wire/stream_blocked_frame.go +++ b/internal/wire/stream_blocked_frame.go @@ -14,18 +14,14 @@ type StreamBlockedFrame struct { // ParseStreamBlockedFrame parses a STREAM_BLOCKED frame func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { - frame := &StreamBlockedFrame{} - - // read the TypeByte - if _, err := r.ReadByte(); err != nil { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } - sid, err := utils.GetByteOrder(version).ReadUint32(r) + sid, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.StreamID = protocol.StreamID(sid) - return frame, nil + return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil } // Write writes a STREAM_BLOCKED frame @@ -34,11 +30,14 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb return (&blockedFrameLegacy{StreamID: f.StreamID}).Write(b, version) } b.WriteByte(0x09) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.WriteVarInt(b, uint64(f.StreamID)) return nil } // MinLength of a written frame func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4, nil + if !version.UsesIETFFrameFormat() { + return 1 + 4, nil + } + return 1 + utils.VarIntLen(uint64(f.StreamID)), nil } diff --git a/internal/wire/stream_blocked_frame_test.go b/internal/wire/stream_blocked_frame_test.go index cc7d9b88..d31ce787 100644 --- a/internal/wire/stream_blocked_frame_test.go +++ b/internal/wire/stream_blocked_frame_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -12,9 +13,9 @@ import ( var _ = Describe("STREAM_BLOCKED frame", func() { Context("parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x9, - 0xde, 0xad, 0xbe, 0xef, // stream id - }) + data := []byte{0x9} + data = append(data, encodeVarInt(0xdeadbeef)...) + b := bytes.NewReader(data) frame, err := ParseStreamBlockedFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) @@ -22,9 +23,8 @@ var _ = Describe("STREAM_BLOCKED frame", func() { }) It("errors on EOFs", func() { - data := []byte{0x9, - 0xef, 0xbe, 0xad, 0xde, // stream id - } + data := []byte{0x9} + data = append(data, encodeVarInt(0xdeadbeef)...) _, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { @@ -39,7 +39,7 @@ var _ = Describe("STREAM_BLOCKED frame", func() { f := &StreamBlockedFrame{ StreamID: 0x1337, } - Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(5))) + Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337))) }) It("writes a sample frame", func() { @@ -49,9 +49,9 @@ var _ = Describe("STREAM_BLOCKED frame", func() { } err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x9, - 0xde, 0xca, 0xfb, 0xad, // stream id - })) + expected := []byte{0x9} + expected = append(expected, encodeVarInt(uint64(f.StreamID))...) + Expect(b.Bytes()).To(Equal(expected)) }) }) }) diff --git a/internal/wire/window_update_frame.go b/internal/wire/window_update_frame.go index 20d7b66f..83d6adb0 100644 --- a/internal/wire/window_update_frame.go +++ b/internal/wire/window_update_frame.go @@ -17,14 +17,24 @@ type windowUpdateFrame struct { // * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream // * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection func ParseWindowUpdateFrame(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { - f, err := ParseMaxStreamDataFrame(r, version) + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + streamID, err := utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } - if f.StreamID == 0 { - return &MaxDataFrame{ByteOffset: f.ByteOffset}, nil + offset, err := utils.GetByteOrder(version).ReadUint64(r) + if err != nil { + return nil, err } - return f, nil + if streamID == 0 { + return &MaxDataFrame{ByteOffset: protocol.ByteCount(offset)}, nil + } + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(streamID), + ByteOffset: protocol.ByteCount(offset), + }, nil } func (f *windowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { diff --git a/packet_unpacker.go b/packet_unpacker.go index 347743cb..5b8d90ae 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -125,7 +125,7 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir case 0x8: frame, err = wire.ParseBlockedFrame(r, u.version) case 0x9: - frame, err = wire.ParseBlockedFrameLegacy(r, u.version) + frame, err = wire.ParseStreamBlockedFrame(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidBlockedData, err.Error()) } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index bc24f88a..c1832adb 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -294,8 +294,8 @@ var _ = Describe("Packet unpacker", func() { It("unpacks RST_STREAM frames", func() { f := &wire.RstStreamFrame{ StreamID: 0xdeadbeef, - ByteOffset: 0xdecafbad11223344, - ErrorCode: 0x13371234, + ByteOffset: 0xdecafbad1234, + ErrorCode: 0x1337, } err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -317,7 +317,7 @@ var _ = Describe("Packet unpacker", func() { It("unpacks MAX_DATA frames", func() { f := &wire.MaxDataFrame{ - ByteOffset: 0xcafe000000001337, + ByteOffset: 0xcafe, } buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) @@ -331,7 +331,7 @@ var _ = Describe("Packet unpacker", func() { It("unpacks MAX_STREAM_DATA frames", func() { f := &wire.MaxStreamDataFrame{ StreamID: 0xdeadbeef, - ByteOffset: 0xcafe000000001337, + ByteOffset: 0xdecafbad, } buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames)