From 0bd7e744ffc8c11b16c594ae6cd8ca03e683687d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jul 2018 12:15:15 +0700 Subject: [PATCH 1/3] implement parsing of headers with different connection ID lengths --- client.go | 2 +- client_multiplexer.go | 2 +- client_test.go | 2 +- internal/protocol/connection_id.go | 2 +- internal/protocol/protocol.go | 6 ++ internal/protocol/server_parameters.go | 9 -- internal/wire/header.go | 5 +- internal/wire/header_parser.go | 13 ++- internal/wire/header_parser_test.go | 113 +++++++++++++--------- internal/wire/header_test.go | 17 +--- internal/wire/version_negotiation_test.go | 4 +- packet_packer_test.go | 2 +- server.go | 2 +- server_test.go | 4 +- server_tls_test.go | 2 +- session_test.go | 2 +- 16 files changed, 97 insertions(+), 90 deletions(-) diff --git a/client.go b/client.go index e8e92ddf6..6937b3441 100644 --- a/client.go +++ b/client.go @@ -364,7 +364,7 @@ func (c *client) handleRead(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) // drop the packet if we can't parse the header if err != nil { c.logger.Errorf("error parsing invariant header: %s", err) diff --git a/client_multiplexer.go b/client_multiplexer.go index e2b1f717c..1b9a28824 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -87,7 +87,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag rcvTime := time.Now() r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) // drop the packet if we can't parse the header if err != nil { m.logger.Debugf("error parsing invariant header from %s: %s", addr, err) diff --git a/client_test.go b/client_test.go index 39284af24..30d2bb5c1 100644 --- a/client_test.go +++ b/client_test.go @@ -866,7 +866,7 @@ var _ = Describe("Client", func() { pr = wire.WritePublicReset(cl.destConnID, 1, 0) r := bytes.NewReader(pr) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) hdr, err = iHdr.Parse(r, protocol.PerspectiveServer, versionGQUICFrames) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index dca4bcd6d..cf2f480df 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -12,7 +12,7 @@ type ConnectionID []byte // GenerateConnectionID generates a connection ID using cryptographic random func GenerateConnectionID() (ConnectionID, error) { - b := make([]byte, ConnectionIDLen) + b := make([]byte, ConnectionIDLenGQUIC) if _, err := rand.Read(b); err != nil { return nil, err } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index e89b22279..2a52f8952 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -82,3 +82,9 @@ const MinInitialPacketSize = 1200 // * one failure due to an incorrect or missing source-address token // * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token const MaxClientHellos = 3 + +// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets. +const ConnectionIDLenGQUIC = 8 + +// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. +const MinConnectionIDLenInitial = 8 diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index 92fb0dd27..c8696fdff 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -145,12 +145,3 @@ const MaxAckFrameSize ByteCount = 1000 // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. const MinPacingDelay time.Duration = 100 * time.Microsecond - -// ConnectionIDLen is the length of the source Connection ID used on IETF QUIC packets. -// The Short Header contains the connection ID, but not the length, -// so we need to know this value in advance (or encode it into the connection ID). -// TODO: make this configurable -const ConnectionIDLen = 8 - -// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. -const MinConnectionIDLenInitial = 8 diff --git a/internal/wire/header.go b/internal/wire/header.go index d53a9872b..99efab75d 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -56,9 +56,6 @@ func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version proto // TODO: add support for the key phase func (h *Header) writeLongHeader(b *bytes.Buffer) error { - if h.SrcConnectionID.Len() != protocol.ConnectionIDLen { - return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len()) - } b.WriteByte(byte(0x80 | h.Type)) utils.BigEndian.WriteUint32(b, uint32(h.Version)) connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) @@ -174,7 +171,7 @@ func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) { return 0, errPacketNumberLenNotSet } length += protocol.ByteCount(h.PacketNumberLen) - length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes + length += protocol.ByteCount(h.DestConnectionID.Len()) // Version Number in packets sent by the client if h.VersionFlag { length += 4 diff --git a/internal/wire/header_parser.go b/internal/wire/header_parser.go index 523c24fe7..e712f76ac 100644 --- a/internal/wire/header_parser.go +++ b/internal/wire/header_parser.go @@ -21,7 +21,7 @@ type InvariantHeader struct { } // ParseInvariantHeader parses the version independent part of the header -func ParseInvariantHeader(b *bytes.Reader) (*InvariantHeader, error) { +func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) { typeByte, err := b.ReadByte() if err != nil { return nil, err @@ -36,8 +36,15 @@ func ParseInvariantHeader(b *bytes.Reader) (*InvariantHeader, error) { // In the IETF Short Header: // * 0x8 it is the gQUIC Demultiplexing bit, and always 0. // * 0x20 and 0x10 are always 1. - if typeByte&0x8 > 0 || typeByte&0x38 == 0x30 { - h.DestConnectionID, err = protocol.ReadConnectionID(b, 8) + var connIDLen int + if typeByte&0x8 > 0 { // Public Header containing a connection ID + connIDLen = 8 + } + if typeByte&0x38 == 0x30 { // Short Header + connIDLen = shortHeaderConnIDLen + } + if connIDLen > 0 { + h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen) if err != nil { return nil, err } diff --git a/internal/wire/header_parser_test.go b/internal/wire/header_parser_test.go index c7d627762..769abf1f4 100644 --- a/internal/wire/header_parser_test.go +++ b/internal/wire/header_parser_test.go @@ -22,13 +22,13 @@ var _ = Describe("Header Parsing", func() { Context("Version Negotiation Packets", func() { It("parses", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} versions := []protocol.VersionNumber{0x22334455, 0x33445566} data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.DestConnectionID).To(Equal(destConnID)) Expect(iHdr.SrcConnectionID).To(Equal(srcConnID)) @@ -50,7 +50,7 @@ var _ = Describe("Header Parsing", func() { data, err := ComposeVersionNegotiation(connID, connID, versions) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data[:len(data)-2]) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFFrames) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) @@ -63,7 +63,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number b := bytes.NewReader(data[:len(data)-8]) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFFrames) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) @@ -71,13 +71,13 @@ var _ = Describe("Header Parsing", func() { }) Context("Long Headers", func() { - It("parses a long header", func() { - destConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x42, 0x42} + It("parses a Long Header", func() { + destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), 0x1, 0x2, 0x3, 0x4, // version number - 0x55, // connection ID lengths + 0x61, // connection ID lengths } data = append(data, destConnID...) data = append(data, srcConnID...) @@ -86,7 +86,7 @@ var _ = Describe("Header Parsing", func() { data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeTrue()) Expect(iHdr.DestConnectionID).To(Equal(destConnID)) @@ -105,7 +105,7 @@ var _ = Describe("Header Parsing", func() { Expect(b.Len()).To(BeZero()) }) - It("parses a long header without a destination connection ID", func() { + It("parses a Long Header without a destination connection ID", func() { data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), 0x1, 0x2, 0x3, 0x4, // version number @@ -115,13 +115,13 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(0x42)...) // payload length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) Expect(iHdr.DestConnectionID).To(BeEmpty()) }) - It("parses a long header without a source connection ID", func() { + It("parses a Long Header without a source connection ID", func() { data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), 0x1, 0x2, 0x3, 0x4, // version number @@ -131,13 +131,13 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(0x42)...) // payload length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.SrcConnectionID).To(BeEmpty()) Expect(iHdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) }) - It("parses a long header with a 2 byte packet number", func() { + It("parses a Long Header with a 2 byte packet number", func() { data := []byte{ 0x80 ^ uint8(protocol.PacketTypeInitial), 0x1, 0x2, 0x3, 0x4, // version number @@ -146,7 +146,7 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(0x42)...) // payload length data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -167,7 +167,7 @@ var _ = Describe("Header Parsing", func() { }).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(buf.Bytes()) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader) Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42")) @@ -182,7 +182,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID } for i := 0; i < len(data); i++ { - _, err := ParseInvariantHeader(bytes.NewReader(data[:i])) + _, err := ParseInvariantHeader(bytes.NewReader(data[:i]), 0) Expect(err).To(Equal(io.EOF)) } }) @@ -198,7 +198,7 @@ var _ = Describe("Header Parsing", func() { data = appendPacketNumber(data, 0xdeadbeef, protocol.PacketNumberLen4) for i := iHdrLen; i < len(data); i++ { b := bytes.NewReader(data[:i]) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).To(Equal(io.EOF)) @@ -207,12 +207,12 @@ var _ = Describe("Header Parsing", func() { }) Context("Short Headers", func() { - It("reads a short header with a connection ID", func() { + It("reads a Short Header with a 8 byte connection ID", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} data := append([]byte{0x30}, connID...) data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 8) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) Expect(iHdr.DestConnectionID).To(Equal(connID)) @@ -226,14 +226,31 @@ var _ = Describe("Header Parsing", func() { Expect(b.Len()).To(BeZero()) }) + It("reads a Short Header with a 5 byte connection ID", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + data := append([]byte{0x30}, connID...) + data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1) + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 5) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + Expect(iHdr.DestConnectionID).To(Equal(connID)) + hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.KeyPhase).To(Equal(0)) + Expect(hdr.DestConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) + Expect(b.Len()).To(BeZero()) + }) + It("reads the Key Phase Bit", func() { data := []byte{ 0x30 ^ 0x40, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID } data = appendPacketNumber(data, 11, protocol.PacketNumberLen1) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 6) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -245,11 +262,11 @@ var _ = Describe("Header Parsing", func() { It("reads a header with a 2 byte packet number", func() { data := []byte{ 0x30 ^ 0x40 ^ 0x1, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, // connection ID } data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen2) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -262,11 +279,11 @@ var _ = Describe("Header Parsing", func() { It("reads a header with a 4 byte packet number", func() { data := []byte{ 0x30 ^ 0x40 ^ 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID } data = appendPacketNumber(data, 0x99beef, protocol.PacketNumberLen4) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 10) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) @@ -282,7 +299,7 @@ var _ = Describe("Header Parsing", func() { 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID } for i := 0; i < len(data); i++ { - _, err := ParseInvariantHeader(bytes.NewReader(data[:i])) + _, err := ParseInvariantHeader(bytes.NewReader(data[:i]), 8) Expect(err).To(Equal(io.EOF)) } }) @@ -290,13 +307,13 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the invariant header", func() { data := []byte{ 0x30 ^ 0x2, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID } iHdrLen := len(data) data = appendPacketNumber(data, 0xdeadbeef, protocol.PacketNumberLen4) for i := iHdrLen; i < len(data); i++ { b := bytes.NewReader(data[:i]) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 6) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader) Expect(err).To(Equal(io.EOF)) @@ -307,10 +324,14 @@ var _ = Describe("Header Parsing", func() { Context("Public Header", func() { It("accepts a sample client header", func() { - ver := make([]byte, 4) - binary.BigEndian.PutUint32(ver, uint32(protocol.SupportedVersions[0])) - b := bytes.NewReader(append(append([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, ver...), 0x01)) - iHdr, err := ParseInvariantHeader(b) + data := []byte{ + 0x9, + 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, + } + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) + data = append(data, 0x1) // packet number + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionPublicHeader) @@ -321,7 +342,7 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} Expect(hdr.DestConnectionID).To(Equal(connID)) Expect(hdr.SrcConnectionID).To(BeEmpty()) - Expect(hdr.Version).To(Equal(protocol.SupportedVersions[0])) + Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) Expect(hdr.SupportedVersions).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(b.Len()).To(BeZero()) @@ -329,7 +350,7 @@ var _ = Describe("Header Parsing", func() { It("accepts an omitted connection ID", func() { b := bytes.NewReader([]byte{0x0, 0x1}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 8) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) Expect(iHdr.DestConnectionID).To(BeEmpty()) @@ -342,7 +363,7 @@ var _ = Describe("Header Parsing", func() { It("parses a PUBLIC_RESET packet", func() { b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) @@ -359,7 +380,7 @@ var _ = Describe("Header Parsing", func() { divNonce := []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} Expect(divNonce).To(HaveLen(32)) b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37)) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 7) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) @@ -380,7 +401,7 @@ var _ = Describe("Header Parsing", func() { data = append(data, []byte{0x13, 37}...) // packet number for i := iHdrLen; i < len(data); i++ { b := bytes.NewReader(data[:i]) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 5) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).To(Equal(io.EOF)) @@ -398,7 +419,7 @@ var _ = Describe("Header Parsing", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x13, 0x37} b := bytes.NewReader(ComposeGQUICVersionNegotiation(connID, versions)) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 6) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -416,7 +437,7 @@ var _ = Describe("Header Parsing", func() { It("errors if it doesn't contain any versions", func() { b := bytes.NewReader([]byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) @@ -428,7 +449,7 @@ var _ = Describe("Header Parsing", func() { data = appendVersion(data, protocol.SupportedVersions[0]) data = appendVersion(data, 99) // unsupported version b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -442,7 +463,7 @@ var _ = Describe("Header Parsing", func() { data := ComposeGQUICVersionNegotiation(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) _, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) @@ -452,7 +473,7 @@ var _ = Describe("Header Parsing", func() { Context("Packet Number lengths", func() { It("accepts 1-byte packet numbers", func() { b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -463,7 +484,7 @@ var _ = Describe("Header Parsing", func() { It("accepts 2-byte packet numbers", func() { b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -474,7 +495,7 @@ var _ = Describe("Header Parsing", func() { It("accepts 4-byte packet numbers", func() { b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde}) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index eeaa05471..0ac2987cf 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -84,21 +84,6 @@ var _ = Describe("Header", func() { Expect(err).To(MatchError("invalid connection ID length: 19 bytes")) }) - It("refuses to write a Long Header with the wrong connection ID length", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - Expect(srcConnID).ToNot(Equal(protocol.ConnectionIDLen)) - err := (&Header{ - IsLongHeader: true, - Type: 0x5, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, // connection IDs must be at most 18 bytes long - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - Version: 0x1020304, - }).Write(buf, protocol.PerspectiveServer, versionIETFHeader) - Expect(err).To(MatchError("Header: source connection ID must be 8 bytes, is 6")) - }) - It("writes a header with an 18 byte connection ID", func() { err := (&Header{ IsLongHeader: true, @@ -537,7 +522,7 @@ var _ = Describe("Header", func() { data, err := ComposeVersionNegotiation(destConnID, srcConnID, []protocol.VersionNumber{0x12345678, 0x87654321}) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index b8194fd5d..0d53142a7 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -14,7 +14,7 @@ var _ = Describe("Version Negotiation Packets", func() { versions := []protocol.VersionNumber{1001, 1003} data := ComposeGQUICVersionNegotiation(connID, versions) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader) Expect(err).ToNot(HaveOccurred()) @@ -32,7 +32,7 @@ var _ = Describe("Version Negotiation Packets", func() { Expect(err).ToNot(HaveOccurred()) Expect(data[0] & 0x80).ToNot(BeZero()) b := bytes.NewReader(data) - iHdr, err := ParseInvariantHeader(b) + iHdr, err := ParseInvariantHeader(b, 4) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader) Expect(err).ToNot(HaveOccurred()) diff --git a/packet_packer_test.go b/packet_packer_test.go index 7cd66febf..81fa87081 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -65,7 +65,7 @@ var _ = Describe("Packet packer", func() { checkPayloadLen := func(data []byte) { r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) diff --git a/server.go b/server.go index da791ab6c..fe6158fb0 100644 --- a/server.go +++ b/server.go @@ -304,7 +304,7 @@ func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } diff --git a/server_test.go b/server_test.go index 140648ff5..2e10e7d7d 100644 --- a/server_test.go +++ b/server_test.go @@ -500,7 +500,7 @@ var _ = Describe("Server", func() { Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) r := bytes.NewReader(conn.dataWritten.Bytes()) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) Expect(iHdr.IsLongHeader).To(BeFalse()) replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames) @@ -546,7 +546,7 @@ var _ = Describe("Server", func() { Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) r := bytes.NewReader(conn.dataWritten.Bytes()) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) diff --git a/server_tls_test.go b/server_tls_test.go index 2f1f32de4..b0e5af1a6 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -73,7 +73,7 @@ var _ = Describe("Stateless TLS handling", func() { unpackPacket := func(data []byte) (*wire.Header, []byte) { r := bytes.NewReader(conn.dataWritten.Bytes()) - iHdr, err := wire.ParseInvariantHeader(r) + iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) diff --git a/session_test.go b/session_test.go index 2926f18ec..12ad17eab 100644 --- a/session_test.go +++ b/session_test.go @@ -1823,7 +1823,7 @@ var _ = Describe("Client Session", func() { sess.queueControlFrame(&wire.PingFrame{}) var packet []byte Eventually(mconn.written).Should(Receive(&packet)) - hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet)) + hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet), 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7})) // make sure the go routine returns From 73f7636537b991661405222e548b2e5bd79483f7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jul 2018 13:48:47 +0700 Subject: [PATCH 2/3] use a random length destination connection ID on the Initial packet The destination connection ID on the Initial packet must be at least 8 bytes long. By using all valid values, we make sure that the everything works correctly. The server chooses a new connection ID with the Retry or Handshake packet it sends, so the overhead of this is negligible. --- client.go | 47 ++++++++++++------------- client_test.go | 4 +-- internal/protocol/connection_id.go | 17 +++++++-- internal/protocol/connection_id_test.go | 28 +++++++++++++-- server_tls.go | 8 +++-- server_tls_test.go | 12 +++---- 6 files changed, 78 insertions(+), 38 deletions(-) diff --git a/client.go b/client.go index 6937b3441..1a3d9286a 100644 --- a/client.go +++ b/client.go @@ -140,23 +140,13 @@ func newClient( ) (*client, error) { clientConfig := populateClientConfig(config) version := clientConfig.Versions[0] - srcConnID, err := generateConnectionID() - if err != nil { - return nil, err - } - destConnID := srcConnID - if version.UsesTLS() { - destConnID, err = generateConnectionID() - if err != nil { - return nil, err - } - } var hostname string if tlsConf != nil { hostname = tlsConf.ServerName } if hostname == "" { + var err error hostname, _, err = net.SplitHostPort(host) if err != nil { return nil, err @@ -175,10 +165,8 @@ func newClient( if closeCallback != nil { onClose = closeCallback } - return &client{ + c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - srcConnID: srcConnID, - destConnID: destConnID, hostname: hostname, tlsConf: tlsConf, config: clientConfig, @@ -186,7 +174,8 @@ func newClient( handshakeChan: make(chan struct{}), closeCallback: onClose, logger: utils.DefaultLogger.WithPrefix("client"), - }, nil + } + return c, c.generateConnectionIDs() } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -243,6 +232,23 @@ func populateClientConfig(config *Config) *Config { } } +func (c *client) generateConnectionIDs() error { + srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC) + if err != nil { + return err + } + destConnID := srcConnID + if c.version.UsesTLS() { + destConnID, err = protocol.GenerateDestinationConnectionID() + if err != nil { + return err + } + } + c.srcConnID = srcConnID + c.destConnID = destConnID + return nil +} + func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) @@ -506,15 +512,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { // switch to negotiated version c.initialVersion = c.version c.version = newVersion - var err error - c.destConnID, err = generateConnectionID() - if err != nil { - return err - } - // in gQUIC, there's only one connection ID - if !c.version.UsesTLS() { - c.srcConnID = c.destConnID - } + c.generateConnectionIDs() + c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.session.Close(errCloseSessionForNewVersion) return nil diff --git a/client_test.go b/client_test.go index 30d2bb5c1..d0856ab69 100644 --- a/client_test.go +++ b/client_test.go @@ -81,11 +81,11 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - var origGenerateConnectionID func() (protocol.ConnectionID, error) + var origGenerateConnectionID func(int) (protocol.ConnectionID, error) BeforeEach(func() { origGenerateConnectionID = generateConnectionID - generateConnectionID = func() (protocol.ConnectionID, error) { + generateConnectionID = func(int) (protocol.ConnectionID, error) { return connID, nil } }) diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index cf2f480df..beacbfcfd 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -10,15 +10,28 @@ import ( // A ConnectionID in QUIC type ConnectionID []byte +const maxConnectionIDLen = 18 + // GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID() (ConnectionID, error) { - b := make([]byte, ConnectionIDLenGQUIC) +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) if _, err := rand.Read(b); err != nil { return nil, err } return ConnectionID(b), nil } +// GenerateDestinationConnectionID generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 18 bytes. +func GenerateDestinationConnectionID() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index 9f7d17de7..3d0d90e2c 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -10,14 +10,38 @@ import ( var _ = Describe("Connection ID generation", func() { It("generates random connection IDs", func() { - c1, err := GenerateConnectionID() + c1, err := GenerateConnectionID(8) Expect(err).ToNot(HaveOccurred()) Expect(c1).ToNot(BeZero()) - c2, err := GenerateConnectionID() + c2, err := GenerateConnectionID(8) Expect(err).ToNot(HaveOccurred()) Expect(c1).ToNot(Equal(c2)) }) + It("generates connection IDs with the requested length", func() { + c, err := GenerateConnectionID(5) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(Equal(5)) + }) + + It("generates random length destination connection IDs", func() { + var has8ByteConnID, has18ByteConnID bool + for i := 0; i < 1000; i++ { + c, err := GenerateDestinationConnectionID() + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(BeNumerically(">=", 8)) + Expect(c.Len()).To(BeNumerically("<=", 18)) + if c.Len() == 8 { + has8ByteConnID = true + } + if c.Len() == 18 { + has18ByteConnID = true + } + } + Expect(has8ByteConnID).To(BeTrue()) + Expect(has18ByteConnID).To(BeTrue()) + }) + It("says if connection IDs are equal", func() { c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} diff --git a/server_tls.go b/server_tls.go index b88372429..5d303b0c8 100644 --- a/server_tls.go +++ b/server_tls.go @@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, StreamID: version.CryptoStreamID(), Data: bc.GetDataForWriting(), } + srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + if err != nil { + return nil, nil, err + } replyHdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, + SrcConnectionID: srcConnID, PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()), PacketNumber: hdr.PacketNumber, // echo the client's packet number PacketNumberLen: hdr.PacketNumberLen, @@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) } params := <-paramsChan - connID, err := protocol.GenerateConnectionID() + connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) if err != nil { return nil, nil, err } diff --git a/server_tls_test.go b/server_tls_test.go index b0e5af1a6..e7c804d1e 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -71,7 +71,7 @@ var _ = Describe("Stateless TLS handling", func() { return hdr, data } - unpackPacket := func(data []byte) (*wire.Header, []byte) { + unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) { r := bytes.NewReader(conn.dataWritten.Bytes()) iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) @@ -80,7 +80,7 @@ var _ = Describe("Stateless TLS handling", func() { hdr.Raw = data[:len(data)-r.Len()] var payload []byte if r.Len() > 0 { - aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw) Expect(err).ToNot(HaveOccurred()) @@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() { } server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr, _ := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.IsVersionNegotiation).To(BeTrue()) Expect(sessionChan).ToNot(Receive()) }) @@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() { hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) server.HandleInitial(nil, hdr, data) Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr, payload := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */)) Expect(sessionChan).ToNot(Receive()) @@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() { // the Handshake packet is written by the session Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) // unpack the packet to check that it actually contains a CONNECTION_CLOSE - replyHdr, data := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) From f02dc92a32a1b590ceb52b5ded1fdd42aca66c55 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jul 2018 16:03:33 +0700 Subject: [PATCH 3/3] make the connection ID length configurable --- Changelog.md | 4 +++ client.go | 29 +++++++++++----- client_multiplexer.go | 40 +++++++++++++--------- client_multiplexer_test.go | 25 +++++++++----- client_test.go | 46 +++++++++++++++++--------- interface.go | 7 ++++ internal/protocol/server_parameters.go | 4 +++ mock_multiplexer_test.go | 11 +++--- server.go | 7 +++- server_test.go | 8 +++++ server_tls.go | 4 +-- session_test.go | 2 +- 12 files changed, 131 insertions(+), 56 deletions(-) diff --git a/Changelog.md b/Changelog.md index 7f22ae3ad..490541375 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,9 @@ # Changelog +## v0.9.0 (unreleased) + +- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC). + ## v0.8.0 (2018-06-26) - Add support for unidirectional streams (for IETF QUIC). diff --git a/client.go b/client.go index 1a3d9286a..ee2f6a3cd 100644 --- a/client.go +++ b/client.go @@ -74,6 +74,7 @@ func DialAddrContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + config = populateClientConfig(config, false) udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -115,8 +116,12 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + config = populateClientConfig(config, true) multiplexer := getClientMultiplexer() - manager := multiplexer.AddConn(pconn) + manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength) + if err != nil { + return nil, err + } c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove) if err != nil { return nil, err @@ -138,9 +143,6 @@ func newClient( host string, closeCallback func(protocol.ConnectionID), ) (*client, error) { - clientConfig := populateClientConfig(config) - version := clientConfig.Versions[0] - var hostname string if tlsConf != nil { hostname = tlsConf.ServerName @@ -169,8 +171,8 @@ func newClient( conn: &conn{pconn: pconn, currentAddr: remoteAddr}, hostname: hostname, tlsConf: tlsConf, - config: clientConfig, - version: version, + config: config, + version: config.Versions[0], handshakeChan: make(chan struct{}), closeCallback: onClose, logger: utils.DefaultLogger.WithPrefix("client"), @@ -180,7 +182,7 @@ func newClient( // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config) *Config { +func populateClientConfig(config *Config, onPacketConn bool) *Config { if config == nil { config = &Config{} } @@ -218,12 +220,17 @@ func populateClientConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 && onPacketConn { + connIDLen = protocol.DefaultConnectionIDLength + } return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, IdleTimeout: idleTimeout, RequestConnectionIDOmission: config.RequestConnectionIDOmission, + ConnectionIDLength: connIDLen, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, @@ -233,7 +240,11 @@ func populateClientConfig(config *Config) *Config { } func (c *client) generateConnectionIDs() error { - srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC) + connIDLen := protocol.ConnectionIDLenGQUIC + if c.version.UsesTLS() { + connIDLen = c.config.ConnectionIDLength + } + srcConnID, err := generateConnectionID(connIDLen) if err != nil { return err } @@ -370,7 +381,7 @@ func (c *client) handleRead(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength) // drop the packet if we can't parse the header if err != nil { c.logger.Errorf("error parsing invariant header: %s", err) diff --git a/client_multiplexer.go b/client_multiplexer.go index 1b9a28824..b53c315f7 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "fmt" "net" "strings" "sync" @@ -19,16 +20,21 @@ var ( ) type multiplexer interface { - AddConn(net.PacketConn) packetHandlerManager + AddConn(net.PacketConn, int) (packetHandlerManager, error) AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error } +type connManager struct { + connIDLen int + manager packetHandlerManager +} + // The clientMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the session handler. type clientMultiplexer struct { mutex sync.Mutex - conns map[net.PacketConn]packetHandlerManager + conns map[net.PacketConn]connManager newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests logger utils.Logger @@ -39,7 +45,7 @@ var _ multiplexer = &clientMultiplexer{} func getClientMultiplexer() multiplexer { clientMuxerOnce.Do(func() { clientMuxer = &clientMultiplexer{ - conns: make(map[net.PacketConn]packetHandlerManager), + conns: make(map[net.PacketConn]connManager), logger: utils.DefaultLogger.WithPrefix("client muxer"), newPacketHandlerManager: newPacketHandlerMap, } @@ -47,30 +53,34 @@ func getClientMultiplexer() multiplexer { return clientMuxer } -func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager { +func (m *clientMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) { m.mutex.Lock() defer m.mutex.Unlock() - sessions, ok := m.conns[c] + p, ok := m.conns[c] if !ok { - sessions = m.newPacketHandlerManager() - m.conns[c] = sessions + manager := m.newPacketHandlerManager() + p = connManager{connIDLen: connIDLen, manager: manager} + m.conns[c] = p // If we didn't know this packet conn before, listen for incoming packets // and dispatch them to the right sessions. - go m.listen(c, sessions) + go m.listen(c, p) } - return sessions + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + return p.manager, nil } func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error { - sessions, ok := m.conns[c] + p, ok := m.conns[c] if !ok { return errors.New("unknown packet conn %s") } - sessions.Add(connID, handler) + p.manager.Add(connID, handler) return nil } -func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) { +func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) { for { data := *getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] @@ -79,7 +89,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag n, addr, err := c.ReadFrom(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - sessions.Close(err) + p.manager.Close(err) } return } @@ -87,13 +97,13 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag rcvTime := time.Now() r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen) // drop the packet if we can't parse the header if err != nil { m.logger.Debugf("error parsing invariant header from %s: %s", addr, err) continue } - client, ok := sessions.Get(iHdr.DestConnectionID) + client, ok := p.manager.Get(iHdr.DestConnectionID) if !ok { m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) continue diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index 87109572e..b5f1d4051 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -29,11 +29,12 @@ var _ = Describe("Client Multiplexer", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} packetHandler := NewMockQuicSession(mockCtrl) handledPacket := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID)) close(handledPacket) }) packetHandler.EXPECT().GetVersion() - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) err := getClientMultiplexer().AddHandler(conn, connID, packetHandler) Expect(err).ToNot(HaveOccurred()) conn.dataToRead <- getPacket(connID) @@ -43,6 +44,14 @@ var _ = Describe("Client Multiplexer", func() { close(conn.dataToRead) }) + It("errors when adding an existing conn with a different connection ID length", func() { + conn := newMockPacketConn() + _, err := getClientMultiplexer().AddConn(conn, 5) + Expect(err).ToNot(HaveOccurred()) + _, err = getClientMultiplexer().AddConn(conn, 6) + Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) + }) + It("errors when adding a handler for an unknown conn", func() { conn := newMockPacketConn() err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl)) @@ -67,7 +76,7 @@ var _ = Describe("Client Multiplexer", func() { close(handledPacket2) }) packetHandler2.EXPECT().GetVersion() - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, connID1.Len()) Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed()) Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed()) @@ -84,10 +93,10 @@ var _ = Describe("Client Multiplexer", func() { It("drops unparseable packets", func() { conn := newMockPacketConn() - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} conn.dataToRead <- []byte("invalid header") packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 7) Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() @@ -106,7 +115,7 @@ var _ = Describe("Client Multiplexer", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} done := make(chan struct{}) manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) conn.dataToRead <- getPacket(connID) Eventually(done).Should(BeClosed()) // makes the listen go routine return @@ -118,7 +127,7 @@ var _ = Describe("Client Multiplexer", func() { conn := newMockPacketConn() conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet // makes the listen go routine return @@ -135,7 +144,7 @@ var _ = Describe("Client Multiplexer", func() { packetHandler.EXPECT().Close(testErr).Do(func(error) { close(done) }) - getClientMultiplexer().AddConn(conn) + getClientMultiplexer().AddConn(conn, 8) Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/client_test.go b/client_test.go index d0856ab69..f653efec1 100644 --- a/client_test.go +++ b/client_test.go @@ -147,7 +147,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) run := make(chan struct{}) @@ -176,7 +176,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("early handshake error") @@ -203,7 +203,7 @@ var _ = Describe("Client", func() { It("closes the session when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) sessionRunning := make(chan struct{}) @@ -243,7 +243,7 @@ var _ = Describe("Client", func() { It("removes closed sessions from the multiplexer", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Remove(connID) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) var runner sessionRunner @@ -279,18 +279,20 @@ var _ = Describe("Client", func() { RequestConnectionIDOmission: true, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, + ConnectionIDLength: 13, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) Expect(c.IdleTimeout).To(Equal(42 * time.Hour)) Expect(c.RequestConnectionIDOmission).To(BeTrue()) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) + Expect(c.ConnectionIDLength).To(Equal(13)) }) It("errors when the Config contains an invalid version", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) @@ -302,7 +304,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: -1, MaxIncomingUniStreams: 4321, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.MaxIncomingStreams).To(BeZero()) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) }) @@ -312,13 +314,25 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: -1, } - c := populateClientConfig(config) + c := populateClientConfig(config, false) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) + It("uses 0-byte connection IDs when dialing an address", func() { + config := &Config{} + c := populateClientConfig(config, false) + Expect(c.ConnectionIDLength).To(BeZero()) + }) + + It("doesn't use 0-byte connection IDs when dialing an address", func() { + config := &Config{} + c := populateClientConfig(config, true) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) + It("fills in default values if options are not set in the Config", func() { - c := populateClientConfig(&Config{}) + c := populateClientConfig(&Config{}, false) Expect(c.Versions).To(Equal(protocol.SupportedVersions)) Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout)) Expect(c.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -329,7 +343,7 @@ var _ = Describe("Client", func() { Context("gQUIC", func() { It("errors if it can't create a session", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("error creating session") @@ -355,7 +369,7 @@ var _ = Describe("Client", func() { Context("IETF QUIC", func() { It("creates new TLS sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} @@ -411,7 +425,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs during version negotiation", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) testErr := errors.New("early handshake error") @@ -568,6 +582,7 @@ var _ = Describe("Client", func() { }) It("drops version negotiation packets that contain the offered version", func() { + cl.config = &Config{} ver := cl.version cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) @@ -581,6 +596,7 @@ var _ = Describe("Client", func() { }) It("ignores packets with an invalid public header", func() { + cl.config = &Config{} cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls cl.handleRead(addr, []byte("invalid packet")) }) @@ -682,7 +698,7 @@ var _ = Describe("Client", func() { It("creates new gQUIC sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: protocol.SupportedVersions} @@ -723,7 +739,7 @@ var _ = Describe("Client", func() { It("creates a new session when the server performs a retry", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} @@ -757,7 +773,7 @@ var _ = Describe("Client", func() { It("only accepts one Retry packet", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} diff --git a/interface.go b/interface.go index 0fdb534ec..b1f19e523 100644 --- a/interface.go +++ b/interface.go @@ -165,6 +165,13 @@ type Config struct { // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // Currently only valid for the client. RequestConnectionIDOmission bool + // The length of the connection ID in bytes. Only valid for IETF QUIC. + // It can be 0, or any value between 4 and 18. + // If not set, the interpretation depends on where the Config is used: + // If used for dialing an address, a 0 byte connection ID will be used. + // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. + // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. + ConnectionIDLength int // HandshakeTimeout is the maximum duration that the cryptographic handshake may take. // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 10 seconds. diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index c8696fdff..aa92c8223 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -145,3 +145,7 @@ const MaxAckFrameSize ByteCount = 1000 // If the packet packing frequency is higher, multiple packets might be sent at once. // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. const MinPacingDelay time.Duration = 100 * time.Microsecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index d1e34a787..aa2f33d87 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -36,15 +36,16 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn) packetHandlerManager { - ret := m.ctrl.Call(m, "AddConn", arg0) +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerManager, error) { + ret := m.ctrl.Call(m, "AddConn", arg0, arg1) ret0, _ := ret[0].(packetHandlerManager) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // AddConn indicates an expected call of AddConn -func (mr *MockMultiplexerMockRecorder) AddConn(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0) +func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1) } // AddHandler mocks base method diff --git a/server.go b/server.go index fe6158fb0..8b560c998 100644 --- a/server.go +++ b/server.go @@ -241,6 +241,10 @@ func populateServerConfig(config *Config) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } + connIDLen := config.ConnectionIDLength + if connIDLen == 0 { + connIDLen = protocol.DefaultConnectionIDLength + } return &Config{ Versions: versions, @@ -252,6 +256,7 @@ func populateServerConfig(config *Config) *Config { MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, + ConnectionIDLength: connIDLen, } } @@ -304,7 +309,7 @@ func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC) + iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } diff --git a/server_test.go b/server_test.go index 2e10e7d7d..00fe8e526 100644 --- a/server_test.go +++ b/server_test.go @@ -48,6 +48,7 @@ var _ = Describe("Server", func() { RequestConnectionIDOmission: true, MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, + ConnectionIDLength: 12, } c := populateServerConfig(config) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) @@ -55,6 +56,7 @@ var _ = Describe("Server", func() { Expect(c.RequestConnectionIDOmission).To(BeFalse()) Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) + Expect(c.ConnectionIDLength).To(Equal(12)) }) It("disables bidirectional streams", func() { @@ -76,6 +78,12 @@ var _ = Describe("Server", func() { Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) + + It("doesn't use 0-byte connection IDs", func() { + config := &Config{} + c := populateClientConfig(config, true) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) }) Context("with mock session", func() { diff --git a/server_tls.go b/server_tls.go index 5d303b0c8..4089af04d 100644 --- a/server_tls.go +++ b/server_tls.go @@ -194,7 +194,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, StreamID: version.CryptoStreamID(), Data: bc.GetDataForWriting(), } - srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err } @@ -228,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) } params := <-paramsChan - connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err } diff --git a/session_test.go b/session_test.go index 12ad17eab..f6225692e 100644 --- a/session_test.go +++ b/session_test.go @@ -1769,7 +1769,7 @@ var _ = Describe("Client Session", func() { protocol.Version39, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, nil, - populateClientConfig(&Config{}), + populateClientConfig(&Config{}, false), protocol.VersionWhatever, nil, utils.DefaultLogger,