From 490d9ddab2e737c8e4c2b32a48049fd172b139ef Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 9 Dec 2019 14:56:08 +0800 Subject: [PATCH] refactor parsing of headers with invalid reserved bits --- internal/wire/extended_header.go | 24 +++++++++++------------- internal/wire/header.go | 10 +++++++++- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index af9b94bb..b789a607 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -28,15 +28,15 @@ type ExtendedHeader struct { PacketNumber protocol.PacketNumber } -func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { // read the (now unencrypted) first byte var err error h.typeByte, err = b.ReadByte() if err != nil { - return nil, err + return false, err } if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil { - return nil, err + return false, err } if h.IsLongHeader { return h.parseLongHeader(b, v) @@ -44,31 +44,29 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte return h.parseShortHeader(b, v) } -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { if err := h.readPacketNumber(b); err != nil { - return nil, err + return false, err } - var err error if h.typeByte&0xc != 0 { - err = ErrInvalidReservedBits + return false, nil } - return h, err + return true, nil } -func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { h.KeyPhase = protocol.KeyPhaseZero if h.typeByte&0x4 > 0 { h.KeyPhase = protocol.KeyPhaseOne } if err := h.readPacketNumber(b); err != nil { - return nil, err + return false, err } - var err error if h.typeByte&0x18 != 0 { - err = ErrInvalidReservedBits + return false, nil } - return h, err + return true, nil } func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { diff --git a/internal/wire/header.go b/internal/wire/header.go index b17e8127..6d5e8a23 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -241,7 +241,15 @@ func (h *Header) ParsedLen() protocol.ByteCount { // ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { - return h.toExtendedHeader().parse(b, ver) + extHdr := h.toExtendedHeader() + reservedBitsValid, err := extHdr.parse(b, ver) + if err != nil { + return nil, err + } + if !reservedBitsValid { + return extHdr, ErrInvalidReservedBits + } + return extHdr, nil } func (h *Header) toExtendedHeader() *ExtendedHeader {