From be59be9ef155500799a66f4752fb3d81c782805a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 9 Dec 2019 15:42:37 +0800 Subject: [PATCH] use the parsed header length when unpacking packets --- internal/wire/extended_header.go | 21 ++++++++++++++++++--- internal/wire/header_test.go | 4 ++++ packet_unpacker.go | 10 +++++----- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index b789a607b..472be9929 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -26,22 +26,32 @@ type ExtendedHeader struct { PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber + + parsedLen protocol.ByteCount } func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { + startLen := b.Len() // read the (now unencrypted) first byte var err error h.typeByte, err = b.ReadByte() if err != nil { return false, err } - if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil { + if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { return false, err } + var reservedBitsValid bool if h.IsLongHeader { - return h.parseLongHeader(b, v) + reservedBitsValid, err = h.parseLongHeader(b, v) + } else { + reservedBitsValid, err = h.parseShortHeader(b, v) } - return h.parseShortHeader(b, v) + if err != nil { + return false, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return reservedBitsValid, err } func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { @@ -186,6 +196,11 @@ func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { return nil } +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + // GetLength determines the length of the Header. func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { if h.IsLongHeader { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index b60761930..36bb6104e 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -188,6 +188,7 @@ var _ = Describe("Header Parsing", func() { Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(b.Len()).To(Equal(6)) // foobar Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 4)) }) It("errors if 0x40 is not set", func() { @@ -382,6 +383,7 @@ var _ = Describe("Header Parsing", func() { Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) Expect(rest).To(Equal([]byte("raboof"))) }) + It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ @@ -438,6 +440,8 @@ var _ = Describe("Header Parsing", func() { Expect(extHdr.DestConnectionID).To(Equal(connID)) Expect(extHdr.SrcConnectionID).To(BeEmpty()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) + Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) Expect(pdata).To(Equal(data)) Expect(rest).To(BeEmpty()) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index 0736d86a3..475c44c29 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -99,7 +99,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, nil, fmt.Errorf("error parsing extended header: %s", parseErr) } - extHdrLen := extHdr.GetLength(u.version) + extHdrLen := extHdr.ParsedLen() decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) if err != nil { return nil, nil, err @@ -123,7 +123,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket( if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, nil, parseErr } - extHdrLen := extHdr.GetLength(u.version) + extHdrLen := extHdr.ParsedLen() decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) if err != nil { return nil, nil, err @@ -137,10 +137,10 @@ func (u *packetUnpacker) unpackShortHeaderPacket( func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) - hdrLen := int(hdr.ParsedLen()) - if len(data) < hdrLen+4+16 { + hdrLen := hdr.ParsedLen() + if protocol.ByteCount(len(data)) < hdrLen+4+16 { //nolint:stylecheck - return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen) + return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) } // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. // 1. save a copy of the 4 bytes