diff --git a/internal/wire/header_parser.go b/internal/wire/header_parser.go index 15e1643f..348e300e 100644 --- a/internal/wire/header_parser.go +++ b/internal/wire/header_parser.go @@ -18,10 +18,21 @@ type InvariantHeader struct { DestConnectionID protocol.ConnectionID typeByte byte + len int // how many bytes were read while parsing this header } // ParseInvariantHeader parses the version independent part of the header func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) { + startLen := b.Len() + h, err := parseInvariantHeaderImpl(b, shortHeaderConnIDLen) + if err != nil { + return nil, err + } + h.len = startLen - b.Len() + return h, nil +} + +func parseInvariantHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) { typeByte, err := b.ReadByte() if err != nil { return nil, err @@ -61,8 +72,12 @@ func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Invariant return h, nil } -// Parse parses the version dependent part of the header +// Parse parses the version dependent part of the header. +// The Reader has to be set such that it points to the first byte of the header. func (iv *InvariantHeader) Parse(b *bytes.Reader, ver protocol.VersionNumber) (*Header, error) { + if _, err := b.Seek(int64(iv.len), io.SeekCurrent); err != nil { + return nil, err + } if iv.IsLongHeader { if iv.Version == 0 { // Version Negotiation Packet return iv.parseVersionNegotiationPacket(b) diff --git a/internal/wire/header_parser_test.go b/internal/wire/header_parser_test.go index be38d5be..6480a446 100644 --- a/internal/wire/header_parser_test.go +++ b/internal/wire/header_parser_test.go @@ -25,12 +25,12 @@ var _ = Describe("Header Parsing", func() { versions := []protocol.VersionNumber{0x22334455, 0x33445566} data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.DestConnectionID).To(Equal(destConnID)) Expect(iHdr.SrcConnectionID).To(Equal(srcConnID)) Expect(iHdr.IsLongHeader).To(BeTrue()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsVersionNegotiation).To(BeTrue()) @@ -40,6 +40,7 @@ var _ = Describe("Header Parsing", func() { for _, v := range versions { Expect(hdr.SupportedVersions).To(ContainElement(v)) } + Expect(b.Len()).To(BeZero()) }) It("errors if it contains versions of the wrong length", func() { @@ -47,10 +48,10 @@ var _ = Describe("Header Parsing", func() { versions := []protocol.VersionNumber{0x22334455, 0x33445566} data, err := ComposeVersionNegotiation(connID, connID, versions) Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data[:len(data)-2]) - iHdr, err := ParseInvariantHeader(b, 0) + data = data[:len(data)-2] + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) - _, err = iHdr.Parse(b, versionIETFFrames) + _, err = iHdr.Parse(bytes.NewReader(data), versionIETFFrames) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) @@ -60,10 +61,10 @@ var _ = Describe("Header Parsing", func() { data, err := ComposeVersionNegotiation(connID, connID, versions) Expect(err).ToNot(HaveOccurred()) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number - b := bytes.NewReader(data[:len(data)-8]) - iHdr, err := ParseInvariantHeader(b, 0) + data = data[:len(data)-8] + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) - _, err = iHdr.Parse(b, versionIETFFrames) + _, err = iHdr.Parse(bytes.NewReader(data), versionIETFFrames) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) }) @@ -85,12 +86,12 @@ var _ = Describe("Header Parsing", func() { // packet number data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeTrue()) Expect(iHdr.DestConnectionID).To(Equal(destConnID)) Expect(iHdr.SrcConnectionID).To(Equal(srcConnID)) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) @@ -115,8 +116,7 @@ var _ = Describe("Header Parsing", func() { } data = append(data, encodeVarInt(0x42)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) Expect(iHdr.DestConnectionID).To(BeEmpty()) @@ -131,8 +131,7 @@ var _ = Describe("Header Parsing", func() { } data = append(data, encodeVarInt(0x42)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.SrcConnectionID).To(BeEmpty()) Expect(iHdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) @@ -148,13 +147,14 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(0x42)...) // length data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) }) It("parses a Retry packet", func() { @@ -166,14 +166,15 @@ var _ = Describe("Header Parsing", func() { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // source connection ID 'f', 'o', 'o', 'b', 'a', 'r', // token } - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 0) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.OrigDestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(b.Len()).To(BeZero()) }) It("rejects packets sent with an unknown packet type", func() { @@ -269,11 +270,11 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x30}, connID...) data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 8) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 8) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) Expect(iHdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.KeyPhase).To(Equal(0)) @@ -288,11 +289,11 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5} data := append([]byte{0x30}, connID...) data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 5) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 5) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) Expect(iHdr.DestConnectionID).To(Equal(connID)) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.KeyPhase).To(Equal(0)) @@ -307,9 +308,9 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID } data = appendPacketNumber(data, 11, protocol.PacketNumberLen1) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 6) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 6) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) @@ -323,9 +324,9 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, // connection ID } data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen2) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 4) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 4) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) @@ -340,9 +341,9 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID } data = appendPacketNumber(data, 0x99beef, protocol.PacketNumberLen4) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 10) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 10) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 47ea6241..03b19fed 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -326,10 +326,9 @@ var _ = Describe("Header", func() { srcConnID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37} data, err := ComposeVersionNegotiation(destConnID, srcConnID, []protocol.VersionNumber{0x12345678, 0x87654321}) Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 4) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 4) Expect(err).ToNot(HaveOccurred()) - hdr, err := iHdr.Parse(b, versionIETFHeader) + hdr, err := iHdr.Parse(bytes.NewReader(data), versionIETFHeader) Expect(err).ToNot(HaveOccurred()) hdr.Log(logger) Expect(buf.String()).To(ContainSubstring("VersionNegotiationPacket{DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337")) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 6c4f82f9..042e9a3a 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -16,9 +16,9 @@ var _ = Describe("Version Negotiation Packets", func() { data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) Expect(data[0] & 0x80).ToNot(BeZero()) - b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b, 4) + iHdr, err := ParseInvariantHeader(bytes.NewReader(data), 4) Expect(err).ToNot(HaveOccurred()) + b := bytes.NewReader(data) hdr, err := iHdr.Parse(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsVersionNegotiation).To(BeTrue()) @@ -30,5 +30,6 @@ var _ = Describe("Version Negotiation Packets", func() { for _, version := range versions { Expect(hdr.SupportedVersions).To(ContainElement(version)) } + Expect(b.Len()).To(BeZero()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index deec1e15..bd4358b9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -205,6 +205,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { } h.mutex.RUnlock() + r = bytes.NewReader(data) hdr, err := iHdr.Parse(r, version) if err != nil { return fmt.Errorf("error parsing header: %s", err) diff --git a/packet_packer_test.go b/packet_packer_test.go index 2287e376..c25422af 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -31,9 +31,9 @@ var _ = Describe("Packet packer", func() { ) checkLength := func(data []byte) { - r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r, 0) + iHdr, err := wire.ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) + r := bytes.NewReader(data) hdr, err := iHdr.Parse(r, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(0, hdr.Length).To(BeEquivalentTo(r.Len() + int(hdr.PacketNumberLen))) diff --git a/server_test.go b/server_test.go index 865760b0..280cce06 100644 --- a/server_test.go +++ b/server_test.go @@ -99,10 +99,9 @@ var _ = Describe("Server", func() { }) parseHeader := func(data []byte) *wire.Header { - b := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(b, 0) + iHdr, err := wire.ParseInvariantHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) - hdr, err := iHdr.Parse(b, protocol.VersionTLS) + hdr, err := iHdr.Parse(bytes.NewReader(data), protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) return hdr }