diff --git a/public_header.go b/public_header.go index 3f8c60b9..0e0f7379 100644 --- a/public_header.go +++ b/public_header.go @@ -3,7 +3,6 @@ package quic import ( "bytes" "errors" - "io" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -112,7 +111,7 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe // ParsePublicHeader parses a QUIC packet's public header // the packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient -func ParsePublicHeader(b io.ByteReader, packetSentBy protocol.Perspective) (*PublicHeader, error) { +func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { header := &PublicHeader{} // First byte @@ -180,6 +179,9 @@ func ParsePublicHeader(b io.ByteReader, packetSentBy protocol.Perspective) (*Pub } header.VersionNumber = protocol.VersionTagToNumber(versionTag) } else { // parse the version negotiaton packet + if b.Len()%4 != 0 { + return nil, qerr.InvalidVersionNegotiationPacket + } header.SupportedVersions = make([]protocol.VersionNumber, 0) for { var versionTag uint32 diff --git a/public_header_test.go b/public_header_test.go index 243e6753..fffb9f28 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -113,6 +114,14 @@ var _ = Describe("Public Header", func() { Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{protocol.VersionUnsupported, protocol.SupportedVersions[0], protocol.VersionUnsupported})) Expect(b.Len()).To(BeZero()) }) + + It("errors on invalid version tags", func() { + data := composeVersionNegotiation(0x1337) + data = append(data, []byte{0x13, 0x37}...) + b := bytes.NewReader(data) + _, err := ParsePublicHeader(b, protocol.PerspectiveServer) + Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) + }) }) Context("Packet Number lengths", func() {