From 28ed85b9c607072d592685d591c19c4c206e1101 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 15 Jan 2019 15:37:35 +0700 Subject: [PATCH 1/4] move cutting of coalesced packets to the wire package --- client_test.go | 2 +- internal/wire/header.go | 40 ++++-- internal/wire/header_test.go | 159 +++++++++++++++------- internal/wire/version_negotiation_test.go | 8 +- packet_handler_map.go | 19 +-- packet_handler_map_test.go | 16 +-- packet_packer_test.go | 4 +- packet_unpacker_test.go | 12 +- server.go | 10 +- server_test.go | 8 +- session_test.go | 2 +- 11 files changed, 173 insertions(+), 107 deletions(-) diff --git a/client_test.go b/client_test.go index ab20ebf66..d0df05fa3 100644 --- a/client_test.go +++ b/client_test.go @@ -58,7 +58,7 @@ var _ = Describe("Client", func() { composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket { data, err := wire.ComposeVersionNegotiation(connID, nil, versions) Expect(err).ToNot(HaveOccurred()) - hdr, err := wire.ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsVersionNegotiation()).To(BeTrue()) return &receivedPacket{ diff --git a/internal/wire/header.go b/internal/wire/header.go index 5b0d6effb..376b362a0 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "errors" + "fmt" "io" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,6 +11,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) +var errUnsupportedVersion = errors.New("unsupported version") + // The Header is the version independent part of the header type Header struct { Version protocol.VersionNumber @@ -28,19 +31,43 @@ type Header struct { parsedLen protocol.ByteCount // how many bytes were read while parsing this header } +// ParsePacket parses a packet. +// If the packet has a long header, the packet is cut according to the length field. +// If we understand the version, the packet is header up unto the packet number. +// Otherwise, only the invariant part of the header is parsed. +func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { + hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) + if err != nil { + if err == errUnsupportedVersion { + return hdr, nil, nil, nil + } + return nil, nil, nil, err + } + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { + return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + return hdr, data, rest, nil +} + // ParseHeader parses the header. // For short header packets: up to the packet number. // For long header packets: // * if we understand the version: up to the packet number // * if not, only the invariant part of the header -func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { +func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { startLen := b.Len() h, err := parseHeaderImpl(b, shortHeaderConnIDLen) if err != nil { - return nil, err + return h, err } h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return h, nil + return h, err } func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { @@ -63,10 +90,7 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) } return h, nil } - if err := h.parseLongHeader(b); err != nil { - return nil, err - } - return h, nil + return h, h.parseLongHeader(b) } func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { @@ -102,7 +126,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { } // If we don't understand the version, we have no idea how to interpret the rest of the bytes if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { - return nil + return errUnsupportedVersion } switch (h.typeByte & 0x30) >> 4 { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 6d138322a..eb6202508 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -24,10 +24,9 @@ var _ = Describe("Header Parsing", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) + vnp, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) - b := bytes.NewReader(data) - hdr, err := ParseHeader(b, 0) + hdr, _, rest, err := ParsePacket(vnp, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) @@ -37,7 +36,7 @@ var _ = Describe("Header Parsing", func() { for _, v := range versions { Expect(hdr.SupportedVersions).To(ContainElement(v)) } - Expect(b.Len()).To(BeZero()) + Expect(rest).To(BeEmpty()) }) It("errors if it contains versions of the wrong length", func() { @@ -45,8 +44,7 @@ var _ = Describe("Header Parsing", func() { versions := []protocol.VersionNumber{0x22334455, 0x33445566} data, err := ComposeVersionNegotiation(connID, connID, versions) Expect(err).ToNot(HaveOccurred()) - data = data[:len(data)-2] - _, err = ParseHeader(bytes.NewReader(data), 0) + _, _, _, err = ParsePacket(data[:len(data)-2], 0) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) @@ -57,7 +55,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] - _, err = ParseHeader(bytes.NewReader(data), 0) + _, _, _, err = ParsePacket(data, 0) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) }) @@ -71,28 +69,31 @@ var _ = Describe("Header Parsing", func() { data = append(data, 0x61) // connection ID lengths data = append(data, destConnID...) data = append(data, srcConnID...) - data = append(data, encodeVarInt(6)...) // token length - data = append(data, []byte("foobar")...) // token - data = append(data, encodeVarInt(0x1337)...) // length + data = append(data, encodeVarInt(6)...) // token length + data = append(data, []byte("foobar")...) // token + data = append(data, encodeVarInt(10)...) // length hdrLen := len(data) - data = append(data, []byte{0, 0, 0xbe, 0xef}...) + data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number + data = append(data, []byte("foobar")...) - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, pdata, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(Equal(data)) Expect(hdr.IsLongHeader).To(BeTrue()) Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdr.Token).To(Equal([]byte("foobar"))) - Expect(hdr.Length).To(Equal(protocol.ByteCount(0x1337))) + Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) Expect(hdr.Version).To(Equal(versionIETFFrames)) + Expect(rest).To(BeEmpty()) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(b.Len()).To(BeZero()) + Expect(b.Len()).To(Equal(6)) // foobar Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) }) @@ -103,7 +104,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xca, 0xfb, 0xad, // dest conn ID 0xde, 0xad, 0xbe, 0xef, // src conn ID } - _, err := ParseHeader(bytes.NewReader(data), 0) + _, _, _, err := ParsePacket(data, 0) Expect(err).To(MatchError("not a QUIC packet")) }) @@ -116,14 +117,13 @@ var _ = Describe("Header Parsing", func() { 0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, 'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes } - b := bytes.NewReader(data) - hdr, err := ParseHeader(b, 0) + hdr, _, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeTrue()) Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) - Expect(b.Len()).To(Equal(6)) + Expect(rest).To(BeEmpty()) }) It("parses a Long Header without a destination connection ID", func() { @@ -131,9 +131,9 @@ var _ = Describe("Header Parsing", func() { data = appendVersion(data, versionIETFFrames) data = append(data, 0x01) // connection ID lengths data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID - data = append(data, encodeVarInt(0x42)...) // length + data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) @@ -145,9 +145,9 @@ var _ = Describe("Header Parsing", func() { data = appendVersion(data, versionIETFFrames) data = append(data, 0x70) // connection ID lengths data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID - data = append(data, encodeVarInt(0x42)...) // length + data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) @@ -158,12 +158,11 @@ var _ = Describe("Header Parsing", func() { data = appendVersion(data, versionIETFFrames) // version number data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(0)...) // token length - data = append(data, encodeVarInt(0x42)...) // length + data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0x1, 0x23}...) - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.Length).To(BeEquivalentTo(0x42)) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) @@ -178,12 +177,13 @@ var _ = Describe("Header Parsing", func() { data = append(data, 0x0) // connection ID lengths data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token - b := bytes.NewReader(data) - hdr, err := ParseHeader(b, 0) + hdr, pdata, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.OrigDestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) }) It("errors if the token length is too large", func() { @@ -194,17 +194,16 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(0x42)...) // length, 1 byte data = append(data, []byte{0x12, 0x34}...) // packet number - _, err := ParseHeader(bytes.NewReader(data), 0) + _, _, _, err := ParsePacket(data, 0) Expect(err).To(MatchError(io.EOF)) }) It("errors if the 5th or 6th bit are set", func() { - data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */} + data := []byte{0xc0 | 0x2<<4 /* set the 5th bit */ | 0x8} data = appendVersion(data, versionIETFFrames) data = append(data, 0x0) // connection ID lengths - data = append(data, 0x42) // packet number - data = append(data, encodeVarInt(1)...) // length - hdr, err := ParseHeader(bytes.NewReader(data), 0) + data = append(data, encodeVarInt(0)...) // length + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) _, err = hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) @@ -218,7 +217,7 @@ var _ = Describe("Header Parsing", func() { data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // destination connection ID data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // source connection ID for i := 0; i < len(data); i++ { - _, err := ParseHeader(bytes.NewReader(data[:i]), 0) + _, _, _, err := ParsePacket(data[:i], 0) Expect(err).To(Equal(io.EOF)) } }) @@ -226,13 +225,13 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the extended header", func() { data := []byte{0xc0 | 0x2<<4 | 0x3} data = appendVersion(data, versionIETFFrames) - data = append(data, 0x0) // connection ID lengths - data = append(data, encodeVarInt(0x1337)...) + data = append(data, 0x0) // connection ID lengths + data = append(data, encodeVarInt(0)...) // length hdrLen := len(data) data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number for i := hdrLen; i < len(data); i++ { data = data[:i] - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) _, err = hdr.ParseExtended(b, versionIETFFrames) @@ -249,13 +248,75 @@ var _ = Describe("Header Parsing", func() { hdrLen := len(data) for i := hdrLen; i < len(data); i++ { data = data[:i] - hdr, err := ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) _, err = hdr.ParseExtended(b, versionIETFFrames) Expect(err).To(Equal(io.EOF)) } }) + + Context("coalesced packets", func() { + It("cuts packets", func() { + buf := &bytes.Buffer{} + hdr := Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 2 + 6, + Version: versionIETFFrames, + } + Expect((&ExtendedHeader{ + Header: hdr, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + hdrRaw := append([]byte{}, buf.Bytes()...) + buf.Write([]byte("foobar")) // payload of the first packet + buf.Write([]byte("raboof")) // second packet + parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(parsedHdr.Type).To(Equal(hdr.Type)) + Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(data).To(Equal(append(hdrRaw, []byte("foobar")...))) + Expect(rest).To(Equal([]byte("raboof"))) + }) + It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 3, + Version: versionIETFFrames, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) + }) + + It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Length: 1000, + Version: versionIETFFrames, + }, + PacketNumber: 0x1337, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + buf.Write(make([]byte, 500-2 /* for packet number length */)) + _, _, _, err := ParsePacket(buf.Bytes(), 4) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + }) }) Context("Short Headers", func() { @@ -263,7 +324,7 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x40}, connID...) data = append(data, 0x42) // packet number - hdr, err := ParseHeader(bytes.NewReader(data), 8) + hdr, pdata, rest, err := ParsePacket(data, 8) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) Expect(hdr.IsVersionNegotiation()).To(BeFalse()) @@ -275,20 +336,21 @@ var _ = Describe("Header Parsing", func() { Expect(extHdr.DestConnectionID).To(Equal(connID)) Expect(extHdr.SrcConnectionID).To(BeEmpty()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) - Expect(b.Len()).To(BeZero()) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) }) It("errors if 0x40 is not set", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x0}, connID...) - _, err := ParseHeader(bytes.NewReader(data), 8) + _, _, _, err := ParsePacket(data, 8) Expect(err).To(MatchError("not a QUIC packet")) }) It("errors if the 4th or 5th bit are set", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5} data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) - hdr, err := ParseHeader(bytes.NewReader(data), 5) + hdr, _, _, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) _, err = hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) @@ -299,8 +361,9 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5} data := append([]byte{0x40}, connID...) data = append(data, 0x42) // packet number - hdr, err := ParseHeader(bytes.NewReader(data), 5) + hdr, pdata, rest, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) + Expect(pdata).To(HaveLen(len(data))) Expect(hdr.IsLongHeader).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) @@ -309,7 +372,7 @@ var _ = Describe("Header Parsing", func() { Expect(extHdr.KeyPhase).To(Equal(0)) Expect(extHdr.DestConnectionID).To(Equal(connID)) Expect(extHdr.SrcConnectionID).To(BeEmpty()) - Expect(b.Len()).To(BeZero()) + Expect(rest).To(BeEmpty()) }) It("reads the Key Phase Bit", func() { @@ -318,7 +381,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID } data = append(data, 11) // packet number - hdr, err := ParseHeader(bytes.NewReader(data), 6) + hdr, _, _, err := ParsePacket(data, 6) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) b := bytes.NewReader(data) @@ -334,7 +397,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, // connection ID } data = append(data, []byte{0x13, 0x37}...) // packet number - hdr, err := ParseHeader(bytes.NewReader(data), 4) + hdr, _, _, err := ParsePacket(data, 4) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) @@ -351,7 +414,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID } data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number - hdr, err := ParseHeader(bytes.NewReader(data), 10) + hdr, _, _, err := ParsePacket(data, 10) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) @@ -369,7 +432,7 @@ var _ = Describe("Header Parsing", func() { } for i := 0; i < len(data); i++ { data = data[:i] - _, err := ParseHeader(bytes.NewReader(data), 8) + _, _, _, err := ParsePacket(data, 8) Expect(err).To(Equal(io.EOF)) } }) @@ -383,7 +446,7 @@ var _ = Describe("Header Parsing", func() { data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number for i := hdrLen; i < len(data); i++ { data = data[:i] - hdr, err := ParseHeader(bytes.NewReader(data), 6) + hdr, _, _, err := ParsePacket(data, 6) Expect(err).ToNot(HaveOccurred()) _, err = hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) Expect(err).To(Equal(io.EOF)) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 8c2933495..1d93c389a 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -1,8 +1,6 @@ package wire import ( - "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -16,9 +14,7 @@ var _ = Describe("Version Negotiation Packets", func() { data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) Expect(data[0] & 0x80).ToNot(BeZero()) - b := bytes.NewReader(data) - hdr, err := ParseHeader(b, 4) - Expect(err).ToNot(HaveOccurred()) + hdr, _, rest, err := ParsePacket(data, 4) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) @@ -28,6 +24,6 @@ var _ = Describe("Version Negotiation Packets", func() { for _, version := range versions { Expect(hdr.SupportedVersions).To(ContainElement(version)) } - Expect(b.Len()).To(BeZero()) + Expect(rest).To(BeEmpty()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index b526b6d95..eb79316c0 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "errors" "fmt" "net" @@ -199,26 +198,16 @@ func (h *packetHandlerMap) parsePacket( var counter int var lastConnID protocol.ConnectionID for len(data) > 0 { - hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen) - // drop the packet if we can't parse the header + hdr, packetData, rest, err := wire.ParsePacket(data, h.connIDLen) if err != nil { - return packets, fmt.Errorf("error parsing header: %s", err) + return packets, fmt.Errorf("error parsing packet: %s", err) } + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) } lastConnID = hdr.DestConnectionID - var rest []byte - if hdr.IsLongHeader { - if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { - return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - packetLen := int(hdr.ParsedLen() + hdr.Length) - rest = data[packetLen:] - data = data[:packetLen] - } - if counter > 0 { buffer.Split() } @@ -227,7 +216,7 @@ func (h *packetHandlerMap) parsePacket( remoteAddr: addr, hdr: hdr, rcvTime: rcvTime, - data: data, + data: packetData, buffer: buffer, }) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 22a412f6d..8f0b4e8cb 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -107,7 +107,7 @@ var _ = Describe("Packet Handler Map", func() { It("drops unparseable packets", func() { _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error parsing header:")) + Expect(err.Error()).To(ContainSubstring("error parsing packet:")) }) It("deletes removed session immediately", func() { @@ -161,20 +161,6 @@ var _ = Describe("Packet Handler Map", func() { }) Context("coalesced packets", func() { - It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := getPacketWithLength(connID, 3) // gets a packet with a 2 byte packet number - _, err := handler.parsePacket(nil, nil, data) - Expect(err).To(MatchError("packet length (2 bytes) is smaller than the expected length (3 bytes)")) - }) - - It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) - _, err := handler.parsePacket(nil, nil, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - It("cuts packets to the right length", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) diff --git a/packet_packer_test.go b/packet_packer_test.go index d488c5f60..caf2b04a8 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -29,7 +29,7 @@ var _ = Describe("Packet packer", func() { ) checkLength := func(data []byte) { - hdr, err := wire.ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever) @@ -808,7 +808,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] - hdr, err := wire.ParseHeader(bytes.NewReader(packet.raw), len(packer.destConnID)) + hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID)) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.raw) extHdr, err := hdr.ParseExtended(r, packer.version) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index af5ba3346..579ff7f00 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -25,10 +25,14 @@ var _ = Describe("Packet Unpacker", func() { getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) - hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), connID.Len()) - Expect(err).ToNot(HaveOccurred()) - return hdr, buf.Bytes() + ExpectWithOffset(1, extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) + hdrLen := buf.Len() + if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { + buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))) + } + hdr, _, _, err := wire.ParsePacket(buf.Bytes(), connID.Len()) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return hdr, buf.Bytes()[:hdrLen] } BeforeEach(func() { diff --git a/server.go b/server.go index 9489fb13a..de6fc86b1 100644 --- a/server.go +++ b/server.go @@ -336,8 +336,14 @@ func (s *server) handlePacket(p *receivedPacket) { return } - // TODO(#943): send Stateless Reset - p.buffer.Release() + defer p.buffer.Release() + // Drop long header packets. + // There's litte point in sending a Stateless Reset, since the client + // might not have received the token yet. + if hdr.IsLongHeader { + return + } + } func (s *server) handleInitial(p *receivedPacket) { diff --git a/server_test.go b/server_test.go index 5091a450b..cfa36cf1d 100644 --- a/server_test.go +++ b/server_test.go @@ -118,7 +118,7 @@ var _ = Describe("Server", func() { }) parseHeader := func(data []byte) *wire.Header { - hdr, err := wire.ParseHeader(bytes.NewReader(data), 0) + hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) return hdr } @@ -239,8 +239,7 @@ var _ = Describe("Server", func() { var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - hdr, err := wire.ParseHeader(bytes.NewReader(write.data), 0) - Expect(err).ToNot(HaveOccurred()) + hdr := parseHeader(write.data) Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(srcConnID)) Expect(hdr.SrcConnectionID).To(Equal(destConnID)) @@ -370,8 +369,7 @@ var _ = Describe("Server", func() { var reject mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reject)) Expect(reject.to).To(Equal(senderAddr)) - rejectHdr, err := wire.ParseHeader(bytes.NewReader(reject.data), 0) - Expect(err).ToNot(HaveOccurred()) + rejectHdr := parseHeader(reject.data) Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(rejectHdr.Version).To(Equal(hdr.Version)) Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) diff --git a/session_test.go b/session_test.go index ce014baaa..0a78cca98 100644 --- a/session_test.go +++ b/session_test.go @@ -496,7 +496,7 @@ var _ = Describe("Session", func() { buf := &bytes.Buffer{} Expect(extHdr.Write(buf, sess.version)).To(Succeed()) // need to set extHdr.Header, since the wire.Header contains the parsed length - hdr, err := wire.ParseHeader(bytes.NewReader(buf.Bytes()), 0) + hdr, _, _, err := wire.ParsePacket(buf.Bytes(), 0) Expect(err).ToNot(HaveOccurred()) extHdr.Header = *hdr return buf.Bytes() From 14426dfa1239022c8e6f63b6679b1ea5108521e2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 15 Jan 2019 16:16:42 +0700 Subject: [PATCH 2/4] implement a function to parse the destination connection ID of a packet --- internal/wire/header.go | 24 ++++++++++++ internal/wire/header_test.go | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/internal/wire/header.go b/internal/wire/header.go index 376b362a0..9461cccb3 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -11,6 +11,30 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) +// ParseConnectionID parses the destination connection ID of a packet. +// It uses the data slice for the connection ID. +// That means that the connection ID must not be used after the packet buffer is released. +func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { + if len(data) == 0 { + return nil, io.EOF + } + isLongHeader := data[0]&0x80 > 0 + if !isLongHeader { + if len(data) < shortHeaderConnIDLen+1 { + return nil, io.EOF + } + return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + } + if len(data) < 6 { + return nil, io.EOF + } + destConnIDLen, _ := decodeConnIDLen(data[5]) + if len(data) < 6+destConnIDLen { + return nil, io.EOF + } + return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil +} + var errUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index eb6202508..b1c8702f4 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -19,6 +19,81 @@ var _ = Describe("Header Parsing", func() { return data } + Context("Parsing the Connection ID", func() { + It("parses the connection ID of a long header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + Version: versionIETFFrames, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + connID, err := ParseConnectionID(buf.Bytes(), 8) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("parses the connection ID of a short header packet", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + buf.Write([]byte("foobar")) + connID, err := ParseConnectionID(buf.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + }) + + It("errors on EOF, for short header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < len(data); i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + + It("errors on EOF, for long header packets", func() { + buf := &bytes.Buffer{} + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + Version: versionIETFFrames, + }, + PacketNumberLen: 2, + }).Write(buf, versionIETFFrames)).To(Succeed()) + data := buf.Bytes()[:buf.Len()-2] // cut the packet number + _, err := ParseConnectionID(data, 8) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { + b := make([]byte, i) + copy(b, data[:i]) + _, err := ParseConnectionID(b, 8) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + Context("Version Negotiation Packets", func() { It("parses", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} From df34e4496e3224f171d4ae28685a62ba32d80020 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 4 Feb 2019 16:20:50 +0800 Subject: [PATCH 3/4] identify version negotiation packets without parsing the header --- client.go | 2 +- client_test.go | 3 ++- internal/wire/header.go | 15 +++++++++------ internal/wire/header_test.go | 25 ++++++++++++++++++++++--- server_test.go | 2 +- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 3eff5aca4..84ad19ace 100644 --- a/client.go +++ b/client.go @@ -287,7 +287,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { } func (c *client) handlePacket(p *receivedPacket) { - if p.hdr.IsVersionNegotiation() { + if wire.IsVersionNegotiationPacket(p.data) { go c.handleVersionNegotiationPacket(p.hdr) return } diff --git a/client_test.go b/client_test.go index d0df05fa3..b35817532 100644 --- a/client_test.go +++ b/client_test.go @@ -60,10 +60,11 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) + Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue()) return &receivedPacket{ rcvTime: time.Now(), hdr: hdr, + data: data, } } diff --git a/internal/wire/header.go b/internal/wire/header.go index 9461cccb3..740d37899 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -35,6 +35,14 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil } +// IsVersionNegotiationPacket says if this is a version negotiation packet +func IsVersionNegotiationPacket(b []byte) bool { + if len(b) < 5 { + return false + } + return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 +} + var errUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header @@ -129,7 +137,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { return err } h.Version = protocol.VersionNumber(v) - if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 { + if h.Version != 0 && h.typeByte&0x40 == 0 { return errors.New("not a QUIC packet") } connIDLenByte, err := b.ReadByte() @@ -214,11 +222,6 @@ func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error { return nil } -// IsVersionNegotiation says if this a version negotiation packet -func (h *Header) IsVersionNegotiation() bool { - return h.IsLongHeader && h.Version == 0 -} - // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *Header) ParsedLen() protocol.ByteCount { return h.parsedLen diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index b1c8702f4..363c50151 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -94,6 +94,24 @@ var _ = Describe("Header Parsing", func() { }) }) + Context("Identifying Version Negotiation Packets", func() { + It("identifies version negotiation packets", func() { + Expect(IsVersionNegotiationPacket([]byte{0x80 | 0x56, 0, 0, 0, 0})).To(BeTrue()) + Expect(IsVersionNegotiationPacket([]byte{0x56, 0, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 1, 0, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 1, 0, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 1, 0})).To(BeFalse()) + Expect(IsVersionNegotiationPacket([]byte{0x80, 0, 0, 0, 1})).To(BeFalse()) + }) + + It("returns false on EOF", func() { + vnp := []byte{0x80, 0, 0, 0, 0} + for i := range vnp { + Expect(IsVersionNegotiationPacket(vnp[:i])).To(BeFalse()) + } + }) + }) + Context("Version Negotiation Packets", func() { It("parses", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} @@ -101,12 +119,12 @@ var _ = Describe("Header Parsing", func() { versions := []protocol.VersionNumber{0x22334455, 0x33445566} vnp, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) + Expect(IsVersionNegotiationPacket(vnp)).To(BeTrue()) hdr, _, rest, err := ParsePacket(vnp, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.Version).To(BeZero()) for _, v := range versions { Expect(hdr.SupportedVersions).To(ContainElement(v)) @@ -150,12 +168,12 @@ var _ = Describe("Header Parsing", func() { hdrLen := len(data) data = append(data, []byte{0, 0, 0xbe, 0xef}...) // packet number data = append(data, []byte("foobar")...) + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) hdr, pdata, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(pdata).To(Equal(data)) Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) @@ -399,10 +417,11 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x40}, connID...) data = append(data, 0x42) // packet number + Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) + hdr, pdata, rest, err := ParsePacket(data, 8) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) - Expect(hdr.IsVersionNegotiation()).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) diff --git a/server_test.go b/server_test.go index cfa36cf1d..b5525fff5 100644 --- a/server_test.go +++ b/server_test.go @@ -239,8 +239,8 @@ var _ = Describe("Server", func() { var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) hdr := parseHeader(write.data) - Expect(hdr.IsVersionNegotiation()).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(srcConnID)) Expect(hdr.SrcConnectionID).To(Equal(destConnID)) Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42))) From 02e851bd1126cace9be5969884c0dcb4fc42fb34 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 23 Feb 2019 10:16:17 +0800 Subject: [PATCH 4/4] cut coalesed packets in the session --- client.go | 12 +- client_test.go | 16 +- mock_packet_handler_test.go | 24 +-- mock_quic_session_test.go | 12 ++ packet_handler_map.go | 114 ++++---------- packet_handler_map_test.go | 103 ++----------- server.go | 86 ++++++----- server_session.go | 59 -------- server_session_test.go | 78 ---------- server_test.go | 162 ++++++++++---------- session.go | 60 ++++++-- session_test.go | 293 ++++++++++++++++++++++-------------- 12 files changed, 442 insertions(+), 577 deletions(-) delete mode 100644 server_session.go delete mode 100644 server_session_test.go diff --git a/client.go b/client.go index 84ad19ace..54a2231e9 100644 --- a/client.go +++ b/client.go @@ -288,7 +288,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { func (c *client) handlePacket(p *receivedPacket) { if wire.IsVersionNegotiationPacket(p.data) { - go c.handleVersionNegotiationPacket(p.hdr) + go c.handleVersionNegotiationPacket(p) return } @@ -301,10 +301,16 @@ func (c *client) handlePacket(p *receivedPacket) { c.session.handlePacket(p) } -func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) { +func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { c.mutex.Lock() defer c.mutex.Unlock() + hdr, _, _, err := wire.ParsePacket(p.data, 0) + if err != nil { + c.logger.Debugf("Error parsing Version Negotiation packet: %s", err) + return + } + // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { c.logger.Debugf("Received a delayed Version Negotiation packet.") @@ -403,6 +409,6 @@ func (c *client) GetVersion() protocol.VersionNumber { return v } -func (c *client) GetPerspective() protocol.Perspective { +func (c *client) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } diff --git a/client_test.go b/client_test.go index b35817532..89358c853 100644 --- a/client_test.go +++ b/client_test.go @@ -58,12 +58,9 @@ var _ = Describe("Client", func() { composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket { data, err := wire.ComposeVersionNegotiation(connID, nil, versions) Expect(err).ToNot(HaveOccurred()) - hdr, _, _, err := wire.ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue()) return &receivedPacket{ rcvTime: time.Now(), - hdr: hdr, data: data, } } @@ -543,19 +540,22 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) }) - It("recognizes that a non version negotiation packet means that the server accepted the suggested version", func() { + It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess cl.config = &Config{} - cl.handlePacket(&receivedPacket{ - hdr: &wire.Header{ + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{ DestConnectionID: connID, SrcConnectionID: connID, Version: cl.version, }, - }) - Eventually(cl.versionNegotiated.Get()).Should(BeTrue()) + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + cl.handlePacket(&receivedPacket{data: buf.Bytes()}) + Eventually(cl.versionNegotiated.Get).Should(BeTrue()) }) It("errors if no matching version is found", func() { diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 4425bcb9f..812e7dccb 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -46,18 +46,6 @@ func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close)) } -// GetPerspective mocks base method -func (m *MockPacketHandler) GetPerspective() protocol.Perspective { - ret := m.ctrl.Call(m, "GetPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 -} - -// GetPerspective indicates an expected call of GetPerspective -func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective)) -} - // destroy mocks base method func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.Call(m, "destroy", arg0) @@ -68,6 +56,18 @@ func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) } +// getPerspective mocks base method +func (m *MockPacketHandler) getPerspective() protocol.Perspective { + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective +func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective)) +} + // handlePacket mocks base method func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { m.ctrl.Call(m, "handlePacket", arg0) diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 226d3beb3..201b1495c 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -231,6 +231,18 @@ func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0) } +// getPerspective mocks base method +func (m *MockQuicSession) getPerspective() protocol.Perspective { + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective +func (mr *MockQuicSessionMockRecorder) getPerspective() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicSession)(nil).getPerspective)) +} + // handlePacket mocks base method func (m *MockQuicSession) handlePacket(arg0 *receivedPacket) { m.ctrl.Call(m, "handlePacket", arg0) diff --git a/packet_handler_map.go b/packet_handler_map.go index eb79316c0..f3e1d019b 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -2,7 +2,6 @@ package quic import ( "errors" - "fmt" "net" "sync" "time" @@ -105,7 +104,7 @@ func (h *packetHandlerMap) CloseServer() { var wg sync.WaitGroup for id, handlerEntry := range h.handlers { handler := handlerEntry.handler - if handler.GetPerspective() == protocol.PerspectiveServer { + if handler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(id string, handler packetHandler) { // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped @@ -174,93 +173,46 @@ func (h *packetHandlerMap) handlePacket( buffer *packetBuffer, data []byte, ) { - packets, err := h.parsePacket(addr, buffer, data) + connID, err := wire.ParseConnectionID(data, h.connIDLen) if err != nil { - h.logger.Debugf("error parsing packets from %s: %s", addr, err) - // This is just the error from parsing the last packet. - // We still need to process the packets that were successfully parsed before. - } - if len(packets) == 0 { - buffer.Release() + h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err) return } - h.handleParsedPackets(packets) -} - -func (h *packetHandlerMap) parsePacket( - addr net.Addr, - buffer *packetBuffer, - data []byte, -) ([]*receivedPacket, error) { rcvTime := time.Now() - packets := make([]*receivedPacket, 0, 1) - var counter int - var lastConnID protocol.ConnectionID - for len(data) > 0 { - hdr, packetData, rest, err := wire.ParsePacket(data, h.connIDLen) - if err != nil { - return packets, fmt.Errorf("error parsing packet: %s", err) - } - - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { - return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) - } - lastConnID = hdr.DestConnectionID - - if counter > 0 { - buffer.Split() - } - counter++ - packets = append(packets, &receivedPacket{ - remoteAddr: addr, - hdr: hdr, - rcvTime: rcvTime, - data: packetData, - buffer: buffer, - }) - - // only log if this actually a coalesced packet - if h.logger.Debug() && (counter > 1 || len(rest) > 0) { - h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packets[counter-1].data), len(rest)) - } - - data = rest - } - return packets, nil -} - -func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) { h.mutex.RLock() defer h.mutex.RUnlock() - // coalesced packets all have the same destination connection ID - handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)] + handlerEntry, handlerFound := h.handlers[string(connID)] - for _, p := range packets { - if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) - continue - } - // No session found. - // This might be a stateless reset. - if !p.hdr.IsLongHeader { - if len(p.data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], p.data[len(p.data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - continue - } - } - // TODO(#943): send a stateless reset - h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID) - break // a short header packet is always the last in a coalesced packet - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID) - continue - } - h.server.handlePacket(p) + p := &receivedPacket{ + remoteAddr: addr, + rcvTime: rcvTime, + buffer: buffer, + data: data, } + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + return + } + // No session found. + // This might be a stateless reset. + if data[0]&0x80 == 0 { // stateless resets are always short header packets + if len(p.data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], p.data[len(p.data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + return + } + } + // TODO(#943): send a stateless reset + h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) + return + } + if h.server == nil { // no server set + h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + h.server.handlePacket(p) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 8f0b4e8cb..2f1e03a88 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,7 +3,6 @@ package quic import ( "bytes" "errors" - "net" "time" "github.com/golang/mock/gomock" @@ -88,11 +87,15 @@ var _ = Describe("Packet Handler Map", func() { handledPacket1 := make(chan struct{}) handledPacket2 := make(chan struct{}) packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID1)) + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) close(handledPacket1) }) packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID2)) + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) close(handledPacket2) }) handler.Add(connID1, packetHandler1) @@ -105,12 +108,10 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error parsing packet:")) + handler.handlePacket(nil, nil, []byte{0, 1, 2, 3}) }) - It("deletes removed session immediately", func() { + It("deletes removed sessions immediately", func() { handler.deleteRetiredSessionsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) @@ -159,64 +160,6 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) - - Context("coalesced packets", func() { - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) - }) - handler.Add(connID, packetHandler) - handler.handlePacket(nil, nil, data) - }) - - It("handles coalesced packets", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - handledPackets := make(chan *receivedPacket, 3) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - handledPackets <- p - }).Times(3) - handler.Add(connID, packetHandler) - - buffer := getPacketBuffer() - packet := buffer.Slice[:0] - packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...) - packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...) - packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...) - conn.dataToRead <- packet - - now := time.Now() - for i := 1; i <= 3; i++ { - var p *receivedPacket - Eventually(handledPackets).Should(Receive(&p)) - Expect(p.hdr.DestConnectionID).To(Equal(connID)) - Expect(p.hdr.Length).To(BeEquivalentTo(10 * i)) - Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length))) - Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond))) - Expect(p.buffer.refCount).To(Equal(3)) - } - }) - - It("ignores coalesced packet parts if the connection IDs don't match", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - - buffer := getPacketBuffer() - packet := buffer.Slice[:0] - // var packet []byte - packet = append(packet, getPacket(connID1)...) - packet = append(packet, getPacket(connID2)...) - - packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet) - Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708")) - Expect(packets).To(HaveLen(1)) - Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1)) - Expect(packets[0].buffer.refCount).To(Equal(1)) - }) - }) }) Context("stateless reset handling", func() { @@ -228,7 +171,9 @@ var _ = Describe("Packet Handler Map", func() { // first send a normal packet handledPacket := make(chan struct{}) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID)) + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) close(handledPacket) }) conn.dataToRead <- getPacket(connID) @@ -250,24 +195,6 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) - It("detects a stateless that is coalesced with another packet", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) - fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - packet := getPacket(fakeConnID) - reset := append([]byte{0x40} /* short header packet */, fakeConnID...) - reset = append(reset, make([]byte, 50)...) // add some "random" data - reset = append(reset, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- append(packet, reset...) - Eventually(destroyed).Should(BeClosed()) - }) - It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} @@ -291,7 +218,9 @@ var _ = Describe("Packet Handler Map", func() { p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID)) + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) }) handler.SetServer(server) handler.handlePacket(nil, nil, p) @@ -299,9 +228,9 @@ var _ = Describe("Packet Handler Map", func() { It("closes all server sessions", func() { clientSess := NewMockPacketHandler(mockCtrl) - clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) + clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) serverSess := NewMockPacketHandler(mockCtrl) - serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer) + serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) serverSess.EXPECT().Close() handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) diff --git a/server.go b/server.go index de6fc86b1..eba17ac1d 100644 --- a/server.go +++ b/server.go @@ -23,7 +23,7 @@ type packetHandler interface { handlePacket(*receivedPacket) io.Closer destroy(error) - GetPerspective() protocol.Perspective + getPerspective() protocol.Perspective } type unknownPacketHandler interface { @@ -44,6 +44,7 @@ type quicSession interface { Session handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber + getPerspective() protocol.Perspective run() error destroy(error) closeForRecreating() protocol.PacketNumber @@ -324,53 +325,60 @@ func (s *server) Addr() net.Addr { } func (s *server) handlePacket(p *receivedPacket) { - hdr := p.hdr - - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - go s.sendVersionNegotiationPacket(p) - return - } - if hdr.Type == protocol.PacketTypeInitial { - go s.handleInitial(p) - return - } - - defer p.buffer.Release() - // Drop long header packets. - // There's litte point in sending a Stateless Reset, since the client - // might not have received the token yet. - if hdr.IsLongHeader { - return - } - + go func() { + if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { + p.buffer.Release() + } + }() } -func (s *server) handleInitial(p *receivedPacket) { - s.logger.Debugf("<- Received Initial packet.") - sess, connID, err := s.handleInitialImpl(p) +func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { + if len(p.data) < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data)) + return false + } + // If we're creating a new session, the packet will be passed to the session. + // The header will then be parsed again. + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + if err != nil { + s.logger.Debugf("Error parsing packet: %s", err) + return false + } + if !hdr.IsLongHeader { + // TODO: send a stateless reset + return false + } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + s.sendVersionNegotiationPacket(p, hdr) + return false + } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's litte point in sending a Stateless Reset, since the client + // might not have received the token yet. + return false + } + + s.logger.Debugf("<- Received Initial packet.") + + sess, connID, err := s.handleInitialImpl(p, hdr) if err != nil { - p.buffer.Release() s.logger.Errorf("Error occurred handling initial packet: %s", err) - return + return false } if sess == nil { // a retry was done, or the connection attempt was rejected - p.buffer.Release() - return + return false } // Don't put the packet buffer back if a new session was created. // The session will handle the packet and take of that. - serverSession := newServerSession(sess, s.config, s.logger) - s.sessionHandler.Add(connID, serverSession) + s.sessionHandler.Add(connID, sess) + return true } -func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { - hdr := p.hdr +func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - return nil, nil, errors.New("dropping Initial packet with too short connection ID") - } - if len(p.data) < protocol.MinInitialPacketSize { - return nil, nil, errors.New("dropping too small Initial packet") + return nil, nil, errors.New("too short connection ID") } var cookie *Cookie @@ -388,7 +396,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con if !s.config.AcceptCookie(p.remoteAddr, cookie) { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. - (&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger) + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) return nil, nil, s.sendRetry(p.remoteAddr, hdr) } @@ -535,9 +543,7 @@ func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { - defer p.buffer.Release() - hdr := p.hdr +func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if err != nil { diff --git a/server_session.go b/server_session.go deleted file mode 100644 index d1ab73a49..000000000 --- a/server_session.go +++ /dev/null @@ -1,59 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -type serverSession struct { - quicSession - - config *Config - - logger utils.Logger -} - -var _ packetHandler = &serverSession{} - -func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler { - return &serverSession{ - quicSession: sess, - config: config, - logger: logger, - } -} - -func (s *serverSession) handlePacket(p *receivedPacket) { - if err := s.handlePacketImpl(p); err != nil { - s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) - } -} - -func (s *serverSession) handlePacketImpl(p *receivedPacket) error { - hdr := p.hdr - - // Probably an old packet that was sent by the client before the version was negotiated. - // It is safe to drop it. - if hdr.IsLongHeader && hdr.Version != s.quicSession.GetVersion() { - return nil - } - - if hdr.IsLongHeader { - switch hdr.Type { - case protocol.PacketTypeInitial, protocol.PacketTypeHandshake: - // nothing to do here. Packet will be passed to the session. - default: - // Note that this also drops 0-RTT packets. - return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) - } - } - - s.quicSession.handlePacket(p) - return nil -} - -func (s *serverSession) GetPerspective() protocol.Perspective { - return protocol.PerspectiveServer -} diff --git a/server_session_test.go b/server_session_test.go deleted file mode 100644 index b350eb517..000000000 --- a/server_session_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package quic - -import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Server Session", func() { - var ( - qsess *MockQuicSession - sess *serverSession - ) - - BeforeEach(func() { - qsess = NewMockQuicSession(mockCtrl) - sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession) - }) - - It("handles packets", func() { - p := &receivedPacket{ - hdr: &wire.Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, - }, - } - qsess.EXPECT().handlePacket(p) - sess.handlePacket(p) - }) - - It("ignores delayed packets with mismatching versions", func() { - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - // don't EXPECT any calls to handlePacket() - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Version: protocol.VersionNumber(123), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - err := sess.handlePacketImpl(p) - Expect(err).ToNot(HaveOccurred()) - }) - - It("ignores packets with the wrong Long Header type", func() { - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - Version: protocol.VersionNumber(100), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - err := sess.handlePacketImpl(p) - Expect(err).To(MatchError("Received unsupported packet type: Retry")) - }) - - It("passes on Handshake packets", func() { - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: protocol.VersionNumber(100), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - qsess.EXPECT().handlePacket(p) - Expect(sess.handlePacketImpl(p)).To(Succeed()) - }) - - It("has the right perspective", func() { - Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer)) - }) -}) diff --git a/server_test.go b/server_test.go index b5525fff5..ca8a8801c 100644 --- a/server_test.go +++ b/server_test.go @@ -26,6 +26,18 @@ var _ = Describe("Server", func() { tlsConf *tls.Config ) + getPacket := func(hdr *wire.Header, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: *hdr, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + BeforeEach(func() { conn = newMockPacketConn() conn.addr = &net.UDPAddr{} @@ -124,53 +136,45 @@ var _ = Describe("Server", func() { } It("drops Initial packets with a too short connection ID", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, nil)) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops too small Initial", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize-100), + )) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops packets with a too short connection ID", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize))) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops non-Initial packets", func() { - serv.logger.SetLogLevel(utils.LogLevelDebug) - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + serv.handlePacket(getPacket( + &wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], }, - data: []byte("invalid"), - })) + []byte("invalid"), + )) }) It("decodes the cookie from the Token field", func() { @@ -187,15 +191,14 @@ var _ = Describe("Server", func() { } token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: raddr, - hdr: &wire.Header{ - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -211,31 +214,29 @@ var _ = Describe("Server", func() { close(done) return false } - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: raddr, - hdr: &wire.Header{ - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: []byte("foobar"), + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -249,16 +250,15 @@ var _ = Describe("Server", func() { It("replies with a Retry packet, if a Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false } hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - serv.handleInitial(insertPacketBuffer(&receivedPacket{ - remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -273,15 +273,13 @@ var _ = Describe("Server", func() { It("creates a session, if no Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) run := make(chan struct{}) serv.newSession = func( _ connection, @@ -309,7 +307,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) // the Handshake packet is written by the session Consistently(conn.dataWritten).ShouldNot(Receive()) close(done) @@ -324,16 +322,14 @@ var _ = Describe("Server", func() { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - remoteAddr: senderAddr, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr serv.newSession = func( _ connection, runner sessionRunner, @@ -360,12 +356,12 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) }() } wg.Wait() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) var reject mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reject)) Expect(reject.to).To(Equal(senderAddr)) @@ -381,16 +377,14 @@ var _ = Describe("Server", func() { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - remoteAddr: senderAddr, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr ctx, cancel := context.WithCancel(context.Background()) sessionCreated := make(chan struct{}) sess := NewMockQuicSession(mockCtrl) @@ -414,7 +408,7 @@ var _ = Describe("Server", func() { return sess, nil } - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) Eventually(sessionCreated).Should(BeClosed()) cancel() @@ -429,7 +423,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sess.EXPECT().Close() + sess.EXPECT().getPerspective() Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/session.go b/session.go index 80448a6f6..feaadc29b 100644 --- a/session.go +++ b/session.go @@ -55,7 +55,6 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr - hdr *wire.Header rcvTime time.Time data []byte @@ -483,7 +482,43 @@ func (s *session) handleHandshakeComplete() { } } -func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ { +func (s *session) handlePacketImpl(p *receivedPacket) bool { + var counter uint8 + var lastConnID protocol.ConnectionID + var processed bool + for len(p.data) > 0 { + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnID.Len()) + if err != nil { + s.logger.Debugf("error parsing packet: %s", err) + break + } + + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + break + } + lastConnID = hdr.DestConnectionID + + if counter > 0 { + p.buffer.Split() + } + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + p.data = packetData + pr := s.handleSinglePacket(p, hdr) + if pr { + processed = pr + } + p.data = rest + } + return processed +} + +func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -493,22 +528,22 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc } }() - if p.hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(p) + if hdr.Type == protocol.PacketTypeRetry { + return s.handleRetryPacket(p, hdr) } // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID) + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID) return false } // drop 0-RTT packets - if p.hdr.Type == protocol.PacketType0RTT { + if hdr.Type == protocol.PacketType0RTT { return false } - packet, err := s.unpacker.Unpack(p.hdr, p.data) + packet, err := s.unpacker.Unpack(hdr, p.data) if err != nil { if err == handshake.ErrOpenerNotYetAvailable { // Sealer for this encryption level not yet available. @@ -524,7 +559,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc } if s.logger.Debug() { - s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), hdr.DestConnectionID, packet.encryptionLevel) packet.hdr.Log(s.logger) } @@ -535,7 +570,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc return true } -func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid Retry */ { +func (s *session) handleRetryPacket(p *receivedPacket, hdr *wire.Header) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { s.logger.Debugf("Ignoring Retry.") return false @@ -544,7 +579,6 @@ func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid s.logger.Debugf("Ignoring Retry, since we already received a packet.") return false } - hdr := p.hdr (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) if !hdr.OrigDestConnectionID.Equal(s.destConnID) { s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID) @@ -1246,6 +1280,10 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } +func (s *session) getPerspective() protocol.Perspective { + return s.perspective +} + func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/session_test.go b/session_test.go index 0a78cca98..ac659f418 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "context" + "crypto/rand" "errors" "net" "runtime/pprof" @@ -354,20 +355,6 @@ var _ = Describe("Session", func() { Expect(str).To(Equal(mstr)) }) - It("drops Retry packets", func() { - hdr := wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - } - buf := &bytes.Buffer{} - (&wire.ExtendedHeader{Header: hdr}).Write(buf, sess.version) - Expect(sess.handlePacketImpl(&receivedPacket{ - hdr: &hdr, - data: buf.Bytes(), - buffer: getPacketBuffer(), - })).To(BeFalse()) - }) - Context("closing", func() { var ( runErr error @@ -492,18 +479,26 @@ var _ = Describe("Session", func() { sess.unpacker = unpacker }) - getData := func(extHdr *wire.ExtendedHeader) []byte { + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} Expect(extHdr.Write(buf, sess.version)).To(Succeed()) - // need to set extHdr.Header, since the wire.Header contains the parsed length - hdr, _, _, err := wire.ParsePacket(buf.Bytes(), 0) - Expect(err).ToNot(HaveOccurred()) - extHdr.Header = *hdr - return buf.Bytes() + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } } + It("drops Retry packets", func() { + hdr := wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + } + Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{Header: hdr}, nil))).To(BeFalse()) + }) + It("informs the ReceivedPacketHandler about non-retransmittable packets", func() { hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } @@ -517,15 +512,14 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - rcvTime: rcvTime, - hdr: &hdr.Header, - data: getData(hdr), - }))).To(BeTrue()) + packet := getPacket(hdr, nil) + packet.rcvTime = rcvTime + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) It("informs the ReceivedPacketHandler about retransmittable packets", func() { hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } @@ -541,11 +535,9 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionHandshake, rcvTime, true) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - rcvTime: rcvTime, - hdr: &hdr.Header, - data: getData(hdr), - }))).To(BeTrue()) + packet := getPacket(hdr, nil) + packet.rcvTime = rcvTime + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) It("drops a packet when unpacking fails", func() { @@ -559,10 +551,10 @@ var _ = Describe("Session", func() { sess.run() }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })) + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return sess.closeLocal(errors.New("close")) @@ -586,65 +578,61 @@ var _ = Describe("Session", func() { close(done) }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })) + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) Eventually(done).Should(BeClosed()) }) - It("handles duplicate packets", func() { - hdr := &wire.ExtendedHeader{ - PacketNumber: 5, - PacketNumberLen: protocol.PacketNumberLen1, - } - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: hdr, - data: []byte{0}, // one PADDING frame - }, nil).Times(2) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) - }) - It("ignores 0-RTT packets", func() { - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketType0RTT, DestConnectionID: sess.srcConnID, }, - }))).To(BeFalse()) + PacketNumberLen: protocol.PacketNumberLen2, + } + Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse()) }) It("ignores packets with a different source connection ID", func() { - hdr := &wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: sess.srcConnID, - Length: 1, + hdr1 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, } + hdr2 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: sess.destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 2, + } + Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{Header: *hdr}, + hdr: hdr1, data: []byte{0}, // one PADDING frame }, nil) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: hdr, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeTrue()) + Expect(sess.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - Length: 1, - }, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) Context("updating the remote address", func() { @@ -657,14 +645,86 @@ var _ = Describe("Session", func() { origAddr := sess.conn.(*mockConnection).remoteAddr remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(origAddr).ToNot(Equal(remoteIP)) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - remoteAddr: remoteIP, - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeTrue()) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + packet.remoteAddr = remoteIP + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) }) }) + + Context("coalesced packets", func() { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + SrcConnectionID: sess.destConnID, + Version: protocol.VersionTLS, + Length: length, + }, + PacketNumberLen: protocol.PacketNumberLen3, + } + hdrLen := hdr.GetLength(sess.version) + b := make([]byte, 1) + rand.Read(b) + packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) + return int(hdrLen), packet + } + + It("cuts packets to the right length", func() { + hdrLen, packet := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + }) + + It("handles coalesced packets", func() { + hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen2 + 123 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + packet1.data = append(packet1.data, packet2.data...) + Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + }) + + It("ignores coalesced packet parts if the destination connection IDs don't match", func() { + wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + Expect(sess.srcConnID).ToNot(Equal(wrongConnID)) + hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + _, packet2 := getPacketWithLength(wrongConnID, 123) + // don't EXPECT any calls to unpacker.Unpack() + packet1.data = append(packet1.data, packet2.data...) + Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + }) + }) }) Context("sending packets", func() { @@ -1436,6 +1496,15 @@ var _ = Describe("Client Session", func() { cryptoSetup *mocks.MockCryptoSetup ) + getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(hdr.Write(buf, sess.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + BeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) @@ -1450,9 +1519,9 @@ var _ = Describe("Client Session", func() { nil, // tls.Config 42, // initial packet number &handshake.TransportParameters{}, - protocol.VersionWhatever, + protocol.VersionTLS, utils.DefaultLogger, - protocol.VersionWhatever, + protocol.VersionTLS, ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) @@ -1479,16 +1548,16 @@ var _ = Describe("Client Session", func() { }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} packer.EXPECT().ChangeDestConnectionID(newConnID) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, SrcConnectionID: newConnID, DestConnectionID: sess.srcConnID, Length: 1, }, - data: []byte{0}, - }))).To(BeTrue()) + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte{0}))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().retireConnectionID(gomock.Any()) @@ -1498,56 +1567,52 @@ var _ = Describe("Client Session", func() { }) Context("handling Retry", func() { - var validRetryHdr *wire.Header + var validRetryHdr *wire.ExtendedHeader BeforeEach(func() { - validRetryHdr = &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, - Token: []byte("foobar"), + validRetryHdr = &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + Token: []byte("foobar"), + Version: sess.version, + }, } }) - getPacket := func(hdr *wire.Header) *receivedPacket { - buf := &bytes.Buffer{} - (&wire.ExtendedHeader{Header: *hdr}).Write(buf, sess.version) - return &receivedPacket{ - hdr: hdr, - data: buf.Bytes(), - buffer: getPacketBuffer(), - } - } - It("handles Retry packets", func() { cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeTrue()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue()) }) It("ignores Retry packets after receiving a regular packet", func() { sess.receivedFirstPacket = true - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse()) }) It("ignores Retry packets if the server didn't change the connection ID", func() { validRetryHdr.SrcConnectionID = sess.destConnID - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse()) }) It("ignores Retry packets with the wrong original destination connection ID", func() { - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Token: []byte("foobar"), + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Token: []byte("foobar"), + }, + PacketNumberLen: protocol.PacketNumberLen3, } - Expect(sess.handlePacketImpl(getPacket(hdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse()) }) })