From 14426dfa1239022c8e6f63b6679b1ea5108521e2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 15 Jan 2019 16:16:42 +0700 Subject: [PATCH] implement a function to parse the destination connection ID of a packet --- internal/wire/header.go | 24 ++++++++++++ internal/wire/header_test.go | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/internal/wire/header.go b/internal/wire/header.go index 376b362a0..9461cccb3 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -11,6 +11,30 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) +// ParseConnectionID parses the destination connection ID of a packet. +// It uses the data slice for the connection ID. +// That means that the connection ID must not be used after the packet buffer is released. +func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { + if len(data) == 0 { + return nil, io.EOF + } + isLongHeader := data[0]&0x80 > 0 + if !isLongHeader { + if len(data) < shortHeaderConnIDLen+1 { + return nil, io.EOF + } + return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + } + if len(data) < 6 { + return nil, io.EOF + } + destConnIDLen, _ := decodeConnIDLen(data[5]) + if len(data) < 6+destConnIDLen { + return nil, io.EOF + } + return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil +} + var errUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index eb6202508..b1c8702f4 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -19,6 +19,81 @@ var _ = Describe("Header Parsing", func() { return data } + Context("Parsing the Connection ID", func() { + It("parses the connection ID of a long header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + Version: versionIETFFrames, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + connID, err := ParseConnectionID(buf.Bytes(), 8) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("parses the connection ID of a short header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + buf.Write([]byte("foobar")) + connID, err := ParseConnectionID(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("errors on EOF, for short header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < len(data); i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + + It("errors on EOF, for long header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + Version: versionIETFFrames, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + Context("Version Negotiation Packets", func() { It("parses", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}