diff --git a/internal/wire/blocked_frame.go b/internal/wire/blocked_frame.go index 72c8a0561..04dd29d88 100644 --- a/internal/wire/blocked_frame.go +++ b/internal/wire/blocked_frame.go @@ -4,17 +4,26 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // A BlockedFrame is a BLOCKED frame -type BlockedFrame struct{} +type BlockedFrame struct { + Offset protocol.ByteCount +} // ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { +func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) { if _, err := r.ReadByte(); err != nil { return nil, err } - return &BlockedFrame{}, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &BlockedFrame{ + Offset: protocol.ByteCount(offset), + }, nil } func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { @@ -23,13 +32,14 @@ func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er } typeByte := uint8(0x08) b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } // MinLength of a written frame func (f *BlockedFrame) MinLength(version protocol.VersionNumber) protocol.ByteCount { - if !version.UsesIETFFrameFormat() { // writing this frame would result in a legacy BLOCKED being written, which is longer + if !version.UsesIETFFrameFormat() { return 1 + 4 } - return 1 + return 1 + utils.VarIntLen(uint64(f.Offset)) } diff --git a/internal/wire/blocked_frame_test.go b/internal/wire/blocked_frame_test.go index 9a3e2dded..ce58820b5 100644 --- a/internal/wire/blocked_frame_test.go +++ b/internal/wire/blocked_frame_test.go @@ -2,8 +2,10 @@ package wire import ( "bytes" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -12,30 +14,41 @@ import ( var _ = Describe("BLOCKED frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x08}) - _, err := ParseBlockedFrame(b, protocol.VersionWhatever) + data := []byte{0x08} + data = append(data, encodeVarInt(0x12345678)...) + b := bytes.NewReader(data) + frame, err := ParseBlockedFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0x12345678))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { - _, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) + data := []byte{0x08} + data = append(data, encodeVarInt(0x12345678)...) + _, err := ParseBlockedFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + for i := range data { + _, err := ParseBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames) + Expect(err).To(MatchError(io.EOF)) + } }) }) Context("when writing", func() { It("writes a sample frame", func() { b := &bytes.Buffer{} - frame := BlockedFrame{} + frame := BlockedFrame{Offset: 0xdeadbeef} err := frame.Write(b, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x08})) + expected := []byte{0x08} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) }) It("has the correct min length", func() { - frame := BlockedFrame{} - Expect(frame.MinLength(versionIETFFrames)).To(Equal(protocol.ByteCount(1))) + frame := BlockedFrame{Offset: 0x12345} + Expect(frame.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x12345))) }) }) }) diff --git a/internal/wire/stream_blocked_frame.go b/internal/wire/stream_blocked_frame.go index b2ecd5845..b67bd24e4 100644 --- a/internal/wire/stream_blocked_frame.go +++ b/internal/wire/stream_blocked_frame.go @@ -10,10 +10,11 @@ import ( // A StreamBlockedFrame in QUIC type StreamBlockedFrame struct { StreamID protocol.StreamID + Offset protocol.ByteCount } // ParseStreamBlockedFrame parses a STREAM_BLOCKED frame -func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { +func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) { if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } @@ -21,7 +22,14 @@ func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (* if err != nil { return nil, err } - return &StreamBlockedFrame{StreamID: protocol.StreamID(sid)}, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamBlockedFrame{ + StreamID: protocol.StreamID(sid), + Offset: protocol.ByteCount(offset), + }, nil } // Write writes a STREAM_BLOCKED frame @@ -31,6 +39,7 @@ func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumb } b.WriteByte(0x09) utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } @@ -39,5 +48,5 @@ func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) protocol. if !version.UsesIETFFrameFormat() { return 1 + 4 } - return 1 + utils.VarIntLen(uint64(f.StreamID)) + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset)) } diff --git a/internal/wire/stream_blocked_frame_test.go b/internal/wire/stream_blocked_frame_test.go index d31ce787b..42b2046ae 100644 --- a/internal/wire/stream_blocked_frame_test.go +++ b/internal/wire/stream_blocked_frame_test.go @@ -14,17 +14,20 @@ var _ = Describe("STREAM_BLOCKED frame", func() { Context("parsing", func() { It("accepts sample frame", func() { data := []byte{0x9} - data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset b := bytes.NewReader(data) frame, err := ParseStreamBlockedFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { data := []byte{0x9} data = append(data, encodeVarInt(0xdeadbeef)...) + data = append(data, encodeVarInt(0xc0010ff)...) _, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionIETFFrames) Expect(err).NotTo(HaveOccurred()) for i := range data { @@ -38,19 +41,22 @@ var _ = Describe("STREAM_BLOCKED frame", func() { It("has proper min length", func() { f := &StreamBlockedFrame{ StreamID: 0x1337, + Offset: 0xdeadbeef, } - Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337))) + Expect(f.MinLength(0)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0xdeadbeef))) }) It("writes a sample frame", func() { b := &bytes.Buffer{} f := &StreamBlockedFrame{ StreamID: 0xdecafbad, + Offset: 0x1337, } err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x9} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) + expected = append(expected, encodeVarInt(uint64(f.Offset))...) Expect(b.Bytes()).To(Equal(expected)) }) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index 9b11085f4..9d09c373f 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -117,6 +117,9 @@ func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wir frame, err = wire.ParsePingFrame(r, u.version) case 0x8: frame, err = wire.ParseBlockedFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } case 0x9: frame, err = wire.ParseStreamBlockedFrame(r, u.version) if err != nil { diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 228169481..8934bb7eb 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -343,7 +343,7 @@ var _ = Describe("Packet unpacker", func() { }) It("unpacks connection-level BLOCKED frames", func() { - f := &wire.BlockedFrame{} + f := &wire.BlockedFrame{Offset: 0x1234} buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -354,7 +354,10 @@ var _ = Describe("Packet unpacker", func() { }) It("unpacks stream-level BLOCKED frames", func() { - f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef} + f := &wire.StreamBlockedFrame{ + StreamID: 0xdeadbeef, + Offset: 0xdead, + } buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -403,6 +406,7 @@ var _ = Describe("Packet unpacker", func() { 0x02: qerr.InvalidConnectionCloseData, 0x04: qerr.InvalidWindowUpdateData, 0x05: qerr.InvalidWindowUpdateData, + 0x08: qerr.InvalidBlockedData, 0x09: qerr.InvalidBlockedData, 0x0c: qerr.InvalidFrameData, 0x0e: qerr.InvalidAckData,