From 8de22b2468d250df4bbe36ef3dead0ba847f36c8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 3 May 2024 13:14:18 +0200 Subject: [PATCH] http3: allow io.EOF when parsing a capsule fails on the first byte (#4476) --- http3/capsule.go | 20 ++++++++++++++++++-- http3/capsule_test.go | 8 ++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/http3/capsule.go b/http3/capsule.go index 7bdcd4e5..69d4037a 100644 --- a/http3/capsule.go +++ b/http3/capsule.go @@ -21,13 +21,29 @@ func (r *exactReader) Read(b []byte) (int, error) { return n, err } +type countingByteReader struct { + io.ByteReader + Read int +} + +func (r *countingByteReader) ReadByte() (byte, error) { + b, err := r.ByteReader.ReadByte() + if err == nil { + r.Read++ + } + return b, err +} + // ParseCapsule parses the header of a Capsule. // It returns an io.LimitedReader that can be used to read the Capsule value. // The Capsule value must be read entirely (i.e. until the io.EOF) before using r again. func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) { - ct, err := quicvarint.Read(r) + cbr := countingByteReader{ByteReader: r} + ct, err := quicvarint.Read(&cbr) if err != nil { - if err == io.EOF { + // If an io.EOF is returned without consuming any bytes, return it unmodified. + // Otherwise, return an io.ErrUnexpectedEOF. + if err == io.EOF && cbr.Read > 0 { return 0, nil, io.ErrUnexpectedEOF } return 0, nil, err diff --git a/http3/capsule_test.go b/http3/capsule_test.go index 4920e887..3e7e25e6 100644 --- a/http3/capsule_test.go +++ b/http3/capsule_test.go @@ -26,7 +26,7 @@ var _ = Describe("Capsule", func() { It("writes capsules", func() { var buf bytes.Buffer - WriteCapsule(&buf, 1337, []byte("foobar")) + Expect(WriteCapsule(&buf, 1337, []byte("foobar"))).To(Succeed()) ct, r, err := ParseCapsule(&buf) Expect(err).ToNot(HaveOccurred()) @@ -44,7 +44,11 @@ var _ = Describe("Capsule", func() { for i := range b { ct, r, err := ParseCapsule(bytes.NewReader(b[:i])) if err != nil { - Expect(err).To(MatchError(io.ErrUnexpectedEOF)) + if i == 0 { + Expect(err).To(MatchError(io.EOF)) + } else { + Expect(err).To(MatchError(io.ErrUnexpectedEOF)) + } continue } Expect(ct).To(BeEquivalentTo(1337))