From af157408dec59a0ea1fbd19ee05aff15521450e5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 25 Nov 2018 15:36:59 +0700 Subject: [PATCH] move parsing of extended headers to the same file as the struct --- internal/wire/extended_header.go | 76 ++++++++++++++++++ internal/wire/{header_parser.go => header.go} | 79 +------------------ .../{header_parser_test.go => header_test.go} | 26 +++--- packet_handler_map.go | 2 +- packet_packer_test.go | 2 +- server_test.go | 2 +- 6 files changed, 96 insertions(+), 91 deletions(-) rename internal/wire/{header_parser.go => header.go} (59%) rename internal/wire/{header_parser_test.go => header_test.go} (94%) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index ee85bdd03..19b216b90 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -4,8 +4,10 @@ import ( "bytes" "crypto/rand" "fmt" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -22,6 +24,7 @@ type ExtendedHeader struct { PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber + typeByte byte Type protocol.PacketType IsLongHeader bool KeyPhase int @@ -29,6 +32,79 @@ type ExtendedHeader struct { Token []byte } +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + if h.IsLongHeader { + return h.parseLongHeader(b, v) + } + return h.parseShortHeader(b, v) +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + h.Type = protocol.PacketType(h.typeByte & 0x7f) + + if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) + } + + if h.Type == protocol.PacketTypeRetry { + odcilByte, err := b.ReadByte() + if err != nil { + return nil, err + } + odcil := decodeSingleConnIDLen(odcilByte & 0xf) + h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) + if err != nil { + return nil, err + } + h.Token = make([]byte, b.Len()) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err + } + return h, nil + } + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + if tokenLen > uint64(b.Len()) { + return nil, io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err + } + } + + pl, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + h.Length = protocol.ByteCount(pl) + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + + return h, nil +} + +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + h.KeyPhase = int(h.typeByte&0x40) >> 6 + + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + + return h, nil +} + // Write writes the Header. func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { if h.IsLongHeader { diff --git a/internal/wire/header_parser.go b/internal/wire/header.go similarity index 59% rename from internal/wire/header_parser.go rename to internal/wire/header.go index 7b400b39a..d06aabc68 100644 --- a/internal/wire/header_parser.go +++ b/internal/wire/header.go @@ -2,7 +2,6 @@ package wire import ( "bytes" - "fmt" "io" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -95,91 +94,21 @@ func (h *Header) IsVersionNegotiation() bool { return h.IsLongHeader() && h.Version == 0 } -// Parse parses the version dependent part of the header. +// ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. -func (h *Header) Parse(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { if _, err := b.Seek(int64(h.len), io.SeekCurrent); err != nil { return nil, err } - if h.IsLongHeader() { - return h.parseLongHeader(b, ver) - } - return h.parseShortHeader(b, ver) + return h.toExtendedHeader().parse(b, ver) } func (h *Header) toExtendedHeader() *ExtendedHeader { return &ExtendedHeader{ IsLongHeader: h.IsLongHeader(), + typeByte: h.typeByte, DestConnectionID: h.DestConnectionID, SrcConnectionID: h.SrcConnectionID, Version: h.Version, } } - -func (h *Header) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { - eh := h.toExtendedHeader() - eh.Type = protocol.PacketType(h.typeByte & 0x7f) - - if eh.Type != protocol.PacketTypeInitial && eh.Type != protocol.PacketTypeRetry && eh.Type != protocol.PacketType0RTT && eh.Type != protocol.PacketTypeHandshake { - return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", eh.Type)) - } - - if eh.Type == protocol.PacketTypeRetry { - odcilByte, err := b.ReadByte() - if err != nil { - return nil, err - } - odcil := decodeSingleConnIDLen(odcilByte & 0xf) - eh.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) - if err != nil { - return nil, err - } - eh.Token = make([]byte, b.Len()) - if _, err := io.ReadFull(b, eh.Token); err != nil { - return nil, err - } - return eh, nil - } - - if eh.Type == protocol.PacketTypeInitial { - tokenLen, err := utils.ReadVarInt(b) - if err != nil { - return nil, err - } - if tokenLen > uint64(b.Len()) { - return nil, io.EOF - } - eh.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, eh.Token); err != nil { - return nil, err - } - } - - pl, err := utils.ReadVarInt(b) - if err != nil { - return nil, err - } - eh.Length = protocol.ByteCount(pl) - pn, pnLen, err := utils.ReadVarIntPacketNumber(b) - if err != nil { - return nil, err - } - eh.PacketNumber = pn - eh.PacketNumberLen = pnLen - - return eh, nil -} - -func (h *Header) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { - eh := h.toExtendedHeader() - eh.KeyPhase = int(h.typeByte&0x40) >> 6 - - pn, pnLen, err := utils.ReadVarIntPacketNumber(b) - if err != nil { - return nil, err - } - eh.PacketNumber = pn - eh.PacketNumberLen = pnLen - - return eh, nil -} diff --git a/internal/wire/header_parser_test.go b/internal/wire/header_test.go similarity index 94% rename from internal/wire/header_parser_test.go rename to internal/wire/header_test.go index 31f0ae2ec..c7e6a7371 100644 --- a/internal/wire/header_parser_test.go +++ b/internal/wire/header_test.go @@ -85,7 +85,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(extHdr.IsLongHeader).To(BeTrue()) @@ -142,7 +142,7 @@ var _ = Describe("Header Parsing", func() { hdr, err := ParseHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) @@ -161,7 +161,7 @@ var _ = Describe("Header Parsing", func() { hdr, err := ParseHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(extHdr.OrigDestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) @@ -183,7 +183,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(buf.Bytes()) hdr, err := ParseHeader(b, 0) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.Parse(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42")) }) @@ -200,7 +200,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data) hdr, err := ParseHeader(b, 0) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.Parse(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(MatchError(io.EOF)) }) @@ -231,7 +231,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data[:i]) hdr, err := ParseHeader(b, 0) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.Parse(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(Equal(io.EOF)) } }) @@ -251,7 +251,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data[:i]) hdr, err := ParseHeader(b, 0) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.Parse(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(Equal(io.EOF)) } }) @@ -268,7 +268,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(0)) Expect(extHdr.DestConnectionID).To(Equal(connID)) @@ -286,7 +286,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.IsLongHeader()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(0)) Expect(extHdr.DestConnectionID).To(Equal(connID)) @@ -304,7 +304,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader()).To(BeFalse()) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(1)) Expect(b.Len()).To(BeZero()) @@ -319,7 +319,7 @@ var _ = Describe("Header Parsing", func() { hdr, err := ParseHeader(bytes.NewReader(data), 4) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.IsLongHeader).To(BeFalse()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) @@ -336,7 +336,7 @@ var _ = Describe("Header Parsing", func() { hdr, err := ParseHeader(bytes.NewReader(data), 10) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.Parse(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.IsLongHeader).To(BeFalse()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) @@ -366,7 +366,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data[:i]) hdr, err := ParseHeader(b, 6) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.Parse(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(Equal(io.EOF)) } }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 08636067c..57994bf11 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -210,7 +210,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { if !hdr.IsVersionNegotiation() { r = bytes.NewReader(data) var err error - extHdr, err = hdr.Parse(r, version) + extHdr, err = hdr.ParseExtended(r, version) if err != nil { return fmt.Errorf("error parsing extended header: %s", err) } diff --git a/packet_packer_test.go b/packet_packer_test.go index 8e4c09a08..02e900b0e 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() { hdr, err := wire.ParseHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(data) - extHdr, err := hdr.Parse(r, protocol.VersionWhatever) + extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } diff --git a/server_test.go b/server_test.go index ace66fa3b..abd8a6d55 100644 --- a/server_test.go +++ b/server_test.go @@ -101,7 +101,7 @@ var _ = Describe("Server", func() { parseHeader := func(data []byte) *wire.ExtendedHeader { hdr, err := wire.ParseHeader(bytes.NewReader(data), 0) Expect(err).ToNot(HaveOccurred()) - extHdr, err := hdr.Parse(bytes.NewReader(data), protocol.VersionTLS) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) return extHdr }