From 9f1eb6286025ff6719260c3b7af0da428d81f2d6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 2 Nov 2017 15:33:13 +0700 Subject: [PATCH] implement the IETF draft BLOCKED and STREAM_BLOCKED frames --- internal/wire/blocked_frame.go | 40 +++++----- internal/wire/blocked_frame_legacy.go | 38 +++++++++ internal/wire/blocked_frame_legacy_test.go | 89 ++++++++++++++++++++++ internal/wire/blocked_frame_test.go | 64 ++++------------ internal/wire/stream_blocked_frame.go | 44 +++++++++++ internal/wire/stream_blocked_frame_test.go | 57 ++++++++++++++ packet_packer_test.go | 16 ++-- packet_unpacker.go | 9 ++- packet_unpacker_test.go | 44 +++++++++-- session.go | 1 + stream_framer.go | 8 +- stream_framer_test.go | 12 +-- 12 files changed, 326 insertions(+), 96 deletions(-) create mode 100644 internal/wire/blocked_frame_legacy.go create mode 100644 internal/wire/blocked_frame_legacy_test.go create mode 100644 internal/wire/stream_blocked_frame.go create mode 100644 internal/wire/stream_blocked_frame_test.go diff --git a/internal/wire/blocked_frame.go b/internal/wire/blocked_frame.go index 76e38f02..08dc051e 100644 --- a/internal/wire/blocked_frame.go +++ b/internal/wire/blocked_frame.go @@ -4,38 +4,32 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) -// A BlockedFrame in QUIC -type BlockedFrame struct { - StreamID protocol.StreamID +// A BlockedFrame is a BLOCKED frame +type BlockedFrame struct{} + +// ParseBlockedFrame parses a BLOCKED frame +func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &BlockedFrame{}, nil } -//Write writes a BlockedFrame frame func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x05) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + if !version.UsesMaxDataFrame() { + return (&blockedFrameLegacy{}).Write(b, version) + } + typeByte := uint8(0x08) + b.WriteByte(typeByte) return nil } // MinLength of a written frame func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4, nil -} - -// ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { - frame := &BlockedFrame{} - - // read the TypeByte - if _, err := r.ReadByte(); err != nil { - return nil, err + if !version.UsesMaxDataFrame() { // writing this frame would result in a legacy BLOCKED being written, which is longer + return 1 + 4, nil } - sid, err := utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - return frame, nil + return 1, nil } diff --git a/internal/wire/blocked_frame_legacy.go b/internal/wire/blocked_frame_legacy.go new file mode 100644 index 00000000..d60ca4c4 --- /dev/null +++ b/internal/wire/blocked_frame_legacy.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type blockedFrameLegacy struct { + StreamID protocol.StreamID +} + +// ParseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) +// The frame returned is +// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream +// * a BLOCKED frame, if the BLOCKED applies to the connection +func ParseBlockedFrameLegacy(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + if streamID == 0 { + return &BlockedFrame{}, nil + } + return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +//Write writes a BLOCKED frame +func (f *blockedFrameLegacy) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x05) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + return nil +} diff --git a/internal/wire/blocked_frame_legacy_test.go b/internal/wire/blocked_frame_legacy_test.go new file mode 100644 index 00000000..1bddd14b --- /dev/null +++ b/internal/wire/blocked_frame_legacy_test.go @@ -0,0 +1,89 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("legacy BLOCKED Frame", func() { + Context("when parsing", func() { + Context("in little endian", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x5, 0xef, 0xbe, 0xad, 0xde}) + f, err := ParseBlockedFrameLegacy(b, versionLittleEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&StreamBlockedFrame{})) + frame := f.(*StreamBlockedFrame) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + }) + }) + + Context("in big endian", func() { + It("accepts sample frame for a stream", func() { + b := bytes.NewReader([]byte{0x5, 0xde, 0xad, 0xbe, 0xef}) + f, err := ParseBlockedFrameLegacy(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&StreamBlockedFrame{})) + frame := f.(*StreamBlockedFrame) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + }) + + It("accepts sample frame for the connection", func() { + b := bytes.NewReader([]byte{0x5, 0x0, 0x0, 0x0, 0x0}) + f, err := ParseBlockedFrameLegacy(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&BlockedFrame{})) + }) + }) + + It("errors on EOFs", func() { + data := []byte{0x5, 0xef, 0xbe, 0xad, 0xde} + _, err := ParseBlockedFrameLegacy(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseBlockedFrameLegacy(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("when writing", func() { + Context("in little endian", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := StreamBlockedFrame{StreamID: 0x1337} + frame.Write(b, versionLittleEndian) + Expect(b.Bytes()).To(Equal([]byte{0x5, 0x37, 0x13, 0x0, 0x0})) + }) + }) + + Context("in big endian", func() { + It("writes a BLOCKED frame for a stream", func() { + b := &bytes.Buffer{} + frame := StreamBlockedFrame{StreamID: 0x1337} + frame.Write(b, versionBigEndian) + Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x13, 0x37})) + }) + + It("has the correct min length for a BLOCKED frame for a stream", func() { + frame := StreamBlockedFrame{StreamID: 3} + Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(5))) + }) + + It("writes a BLOCKED frame for the connection", func() { + b := &bytes.Buffer{} + frame := BlockedFrame{} + frame.Write(b, versionBigEndian) + Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x0, 0x0})) + }) + + It("has the correct min length for a BLOCKED frame for the connection", func() { + frame := BlockedFrame{} + Expect(frame.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(5))) + }) + }) + }) +}) diff --git a/internal/wire/blocked_frame_test.go b/internal/wire/blocked_frame_test.go index ac7048a3..b48065b6 100644 --- a/internal/wire/blocked_frame_test.go +++ b/internal/wire/blocked_frame_test.go @@ -4,70 +4,38 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("BlockedFrame", func() { +var _ = Describe("BLOCKED frame", func() { Context("when parsing", func() { - Context("in little endian", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x5, 0xef, 0xbe, 0xad, 0xde}) - frame, err := ParseBlockedFrame(b, versionLittleEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - }) - }) - - Context("in big endian", func() { - It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0x5, 0xde, 0xad, 0xbe, 0xef}) - frame, err := ParseBlockedFrame(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) - }) + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x08}) + _, err := ParseBlockedFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) }) It("errors on EOFs", func() { - data := []byte{0x5, 0xef, 0xbe, 0xad, 0xde} - _, err := ParseBlockedFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := ParseBlockedFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } + _, err := ParseBlockedFrame(bytes.NewReader(nil), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) }) }) Context("when writing", func() { - Context("in little endian", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := BlockedFrame{StreamID: 0x1337} - frame.Write(b, versionLittleEndian) - Expect(b.Bytes()).To(Equal([]byte{0x5, 0x37, 0x13, 0x0, 0x0})) - }) - }) - - Context("in big endian", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := BlockedFrame{StreamID: 0x1337} - frame.Write(b, versionBigEndian) - Expect(b.Bytes()).To(Equal([]byte{0x5, 0x0, 0x0, 0x13, 0x37})) - }) - }) - - It("writes a connection-level Blocked", func() { + It("writes a sample frame", func() { b := &bytes.Buffer{} - frame := BlockedFrame{StreamID: 0} - frame.Write(b, 0) - Expect(b.Bytes()).To(Equal([]byte{0x5, 0, 0, 0, 0})) + frame := BlockedFrame{} + err := frame.Write(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x08})) }) It("has the correct min length", func() { - frame := BlockedFrame{StreamID: 3} - Expect(frame.MinLength(0)).To(Equal(protocol.ByteCount(5))) + frame := BlockedFrame{} + Expect(frame.MinLength(versionMaxDataFrame)).To(Equal(protocol.ByteCount(1))) }) }) }) diff --git a/internal/wire/stream_blocked_frame.go b/internal/wire/stream_blocked_frame.go new file mode 100644 index 00000000..981c0ecc --- /dev/null +++ b/internal/wire/stream_blocked_frame.go @@ -0,0 +1,44 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamBlockedFrame in QUIC +type StreamBlockedFrame struct { + StreamID protocol.StreamID +} + +// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame +func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { + frame := &StreamBlockedFrame{} + + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + sid, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + return frame, nil +} + +// Write writes a STREAM_BLOCKED frame +func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesMaxDataFrame() { + return (&blockedFrameLegacy{StreamID: f.StreamID}).Write(b, version) + } + b.WriteByte(0x09) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + return nil +} + +// MinLength of a written frame +func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + return 1 + 4, nil +} diff --git a/internal/wire/stream_blocked_frame_test.go b/internal/wire/stream_blocked_frame_test.go new file mode 100644 index 00000000..002e8889 --- /dev/null +++ b/internal/wire/stream_blocked_frame_test.go @@ -0,0 +1,57 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM_BLOCKED frame", func() { + Context("parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x9, + 0xde, 0xad, 0xbe, 0xef, // stream id + }) + frame, err := ParseStreamBlockedFrame(b, versionMaxDataFrame) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x9, + 0xef, 0xbe, 0xad, 0xde, // stream id + } + _, err := ParseStreamBlockedFrame(bytes.NewReader(data), versionMaxDataFrame) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseStreamBlockedFrame(bytes.NewReader(data[0:i]), versionMaxDataFrame) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("has proper min length", func() { + f := &StreamBlockedFrame{ + StreamID: 0x1337, + } + Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(5))) + }) + + It("writes a sample frame", func() { + b := &bytes.Buffer{} + f := &StreamBlockedFrame{ + StreamID: 0xdecafbad, + } + err := f.Write(b, versionMaxDataFrame) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x9, + 0xde, 0xca, 0xfb, 0xad, // stream id + })) + }) + }) +}) diff --git a/packet_packer_test.go b/packet_packer_test.go index 100f6c81..76e1e6ed 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -301,7 +301,7 @@ var _ = Describe("Packet packer", func() { }) It("packs a packet if it has queued control frames, but no new control frames", func() { - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{StreamID: 0}} + packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -326,9 +326,7 @@ var _ = Describe("Packet packer", func() { }) It("packs a lot of control frames into 2 packets if they don't fit into one", func() { - blockedFrame := &wire.BlockedFrame{ - StreamID: 0x1337, - } + blockedFrame := &wire.BlockedFrame{} minLength, _ := blockedFrame.MinLength(0) maxFramesPerPacket := int(maxFrameSize) / int(minLength) var controlFrames []wire.Frame @@ -607,7 +605,7 @@ var _ = Describe("Packet packer", func() { Context("Blocked frames", func() { It("queues a BLOCKED frame", func() { length := 100 - streamFramer.blockedFrameQueue = []*wire.BlockedFrame{{StreamID: 5}} + streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} f := &wire.StreamFrame{ StreamID: 5, Data: bytes.Repeat([]byte{'f'}, length), @@ -615,12 +613,12 @@ var _ = Describe("Packet packer", func() { streamFramer.AddFrameForRetransmission(f) _, err := packer.composeNextPacket(maxFrameSize, true) Expect(err).ToNot(HaveOccurred()) - Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{StreamID: 5})) + Expect(packer.controlFrames[0]).To(Equal(&wire.StreamBlockedFrame{StreamID: 5})) }) It("removes the dataLen attribute from the last StreamFrame, even if it queued a BLOCKED frame", func() { length := 100 - streamFramer.blockedFrameQueue = []*wire.BlockedFrame{{StreamID: 5}} + streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} f := &wire.StreamFrame{ StreamID: 5, Data: bytes.Repeat([]byte{'f'}, length), @@ -633,7 +631,7 @@ var _ = Describe("Packet packer", func() { }) It("packs a connection-level BlockedFrame", func() { - streamFramer.blockedFrameQueue = []*wire.BlockedFrame{{StreamID: 0}} + streamFramer.blockedFrameQueue = []wire.Frame{&wire.BlockedFrame{}} f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), @@ -641,7 +639,7 @@ var _ = Describe("Packet packer", func() { streamFramer.AddFrameForRetransmission(f) _, err := packer.composeNextPacket(maxFrameSize, true) Expect(err).ToNot(HaveOccurred()) - Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{StreamID: 0})) + Expect(packer.controlFrames[0]).To(Equal(&wire.BlockedFrame{})) }) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index 638ab1a3..f891e37f 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -94,7 +94,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) } } else if typeByte == 0x5 { // in gQUIC, 0x5 is a BLOCKED frame - frame, err = wire.ParseBlockedFrame(r, u.version) + frame, err = wire.ParseBlockedFrameLegacy(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidBlockedData, err.Error()) } @@ -105,6 +105,13 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by } } else if typeByte == 0x7 { frame, err = wire.ParsePingFrame(r, u.version) + } else if u.version.UsesMaxDataFrame() && typeByte == 0x8 { // in IETF QUIC, 0x4 is a BLOCKED frame + frame, err = wire.ParseBlockedFrame(r, u.version) + } else if u.version.UsesMaxDataFrame() && typeByte == 0x9 { // in IETF QUIC, 0x4 is a STREAM_BLOCKED frame + frame, err = wire.ParseBlockedFrameLegacy(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } } else { err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 3e56dd17..d2a76393 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -183,10 +183,33 @@ var _ = Describe("Packet unpacker", func() { Expect(packet.frames).To(Equal([]wire.Frame{f})) }) + It("unpacks connection-level BLOCKED frames", func() { + f := &wire.BlockedFrame{} + buf := &bytes.Buffer{} + err := f.Write(buf, versionCryptoStream0) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("unpacks stream-level BLOCKED frames", func() { + f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef} + buf := &bytes.Buffer{} + err := f.Write(buf, versionCryptoStream0) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + It("errors on invalid frames", func() { for b, e := range map[byte]qerr.ErrorCode{ 0x04: qerr.InvalidWindowUpdateData, 0x05: qerr.InvalidWindowUpdateData, + 0x09: qerr.InvalidBlockedData, } { setData([]byte{b}) _, err := unpacker.Unpack(hdrBin, hdr, data) @@ -227,10 +250,21 @@ var _ = Describe("Packet unpacker", func() { Expect(packet.frames).To(Equal([]wire.Frame{f})) }) - It("unpacks BLOCKED frames", func() { - f := &wire.BlockedFrame{StreamID: 0xDEADBEEF} + It("unpacks connection-level BLOCKED frames", func() { + f := &wire.BlockedFrame{} buf := &bytes.Buffer{} - err := f.Write(buf, protocol.VersionWhatever) + err := f.Write(buf, versionCryptoStream1) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("unpacks stream-level BLOCKED frames", func() { + f := &wire.StreamBlockedFrame{StreamID: 0xdeadbeef} + buf := &bytes.Buffer{} + err := f.Write(buf, versionCryptoStream1) Expect(err).ToNot(HaveOccurred()) setData(buf.Bytes()) packet, err := unpacker.Unpack(hdrBin, hdr, data) @@ -269,9 +303,9 @@ var _ = Describe("Packet unpacker", func() { }) It("errors on invalid type", func() { - setData([]byte{0x08}) + setData([]byte{0xf}) _, err := unpacker.Unpack(hdrBin, hdr, data) - Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0x8")) + Expect(err).To(MatchError("InvalidFrameData: unknown type byte 0xf")) }) It("errors on invalid frames", func() { diff --git a/session.go b/session.go index b4586f08..cf1ac88c 100644 --- a/session.go +++ b/session.go @@ -496,6 +496,7 @@ func (s *session) handleFrames(fs []wire.Frame) error { case *wire.MaxStreamDataFrame: err = s.handleMaxStreamDataFrame(frame) case *wire.BlockedFrame: + case *wire.StreamBlockedFrame: case *wire.PingFrame: default: return errors.New("Session BUG: unexpected frame type") diff --git a/stream_framer.go b/stream_framer.go index 00755b58..8928e490 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -13,7 +13,7 @@ type streamFramer struct { connFlowController flowcontrol.ConnectionFlowController retransmissionQueue []*wire.StreamFrame - blockedFrameQueue []*wire.BlockedFrame + blockedFrameQueue []wire.Frame } func newStreamFramer( @@ -37,7 +37,7 @@ func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.Stream return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) } -func (f *streamFramer) PopBlockedFrame() *wire.BlockedFrame { +func (f *streamFramer) PopBlockedFrame() wire.Frame { if len(f.blockedFrameQueue) == 0 { return nil } @@ -132,10 +132,10 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] // Finally, check if we are now FC blocked and should queue a BLOCKED frame if !frame.FinBit && s.IsFlowControlBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{StreamID: s.StreamID()}) + f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) } if f.connFlowController.IsBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{StreamID: 0}) + f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{}) } res = append(res, frame) diff --git a/stream_framer_test.go b/stream_framer_test.go index 054147cc..e2d5e6a2 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -326,9 +326,10 @@ var _ = Describe("Stream Framer", func() { setNoData(stream2) frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) - blockedFrame := framer.PopBlockedFrame() - Expect(blockedFrame).ToNot(BeNil()) - Expect(blockedFrame.StreamID).To(Equal(stream1.StreamID())) + f := framer.PopBlockedFrame() + Expect(f).To(BeAssignableToTypeOf(&wire.StreamBlockedFrame{})) + bf := f.(*wire.StreamBlockedFrame) + Expect(bf.StreamID).To(Equal(stream1.StreamID())) Expect(framer.PopBlockedFrame()).To(BeNil()) }) @@ -357,9 +358,8 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().IsFlowControlBlocked().Return(false) setNoData(stream2) framer.PopStreamFrames(1000) - blockedFrame := framer.PopBlockedFrame() - Expect(blockedFrame).ToNot(BeNil()) - Expect(blockedFrame.StreamID).To(BeZero()) + f := framer.PopBlockedFrame() + Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{})) Expect(framer.PopBlockedFrame()).To(BeNil()) }) })