From 73f7636537b991661405222e548b2e5bd79483f7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 1 Jul 2018 13:48:47 +0700 Subject: [PATCH] use a random length destination connection ID on the Initial packet The destination connection ID on the Initial packet must be at least 8 bytes long. By using all valid values, we make sure that the everything works correctly. The server chooses a new connection ID with the Retry or Handshake packet it sends, so the overhead of this is negligible. --- client.go | 47 ++++++++++++------------- client_test.go | 4 +-- internal/protocol/connection_id.go | 17 +++++++-- internal/protocol/connection_id_test.go | 28 +++++++++++++-- server_tls.go | 8 +++-- server_tls_test.go | 12 +++---- 6 files changed, 78 insertions(+), 38 deletions(-) diff --git a/client.go b/client.go index 6937b3441..1a3d9286a 100644 --- a/client.go +++ b/client.go @@ -140,23 +140,13 @@ func newClient( ) (*client, error) { clientConfig := populateClientConfig(config) version := clientConfig.Versions[0] - srcConnID, err := generateConnectionID() - if err != nil { - return nil, err - } - destConnID := srcConnID - if version.UsesTLS() { - destConnID, err = generateConnectionID() - if err != nil { - return nil, err - } - } var hostname string if tlsConf != nil { hostname = tlsConf.ServerName } if hostname == "" { + var err error hostname, _, err = net.SplitHostPort(host) if err != nil { return nil, err @@ -175,10 +165,8 @@ func newClient( if closeCallback != nil { onClose = closeCallback } - return &client{ + c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - srcConnID: srcConnID, - destConnID: destConnID, hostname: hostname, tlsConf: tlsConf, config: clientConfig, @@ -186,7 +174,8 @@ func newClient( handshakeChan: make(chan struct{}), closeCallback: onClose, logger: utils.DefaultLogger.WithPrefix("client"), - }, nil + } + return c, c.generateConnectionIDs() } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -243,6 +232,23 @@ func populateClientConfig(config *Config) *Config { } } +func (c *client) generateConnectionIDs() error { + srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC) + if err != nil { + return err + } + destConnID := srcConnID + if c.version.UsesTLS() { + destConnID, err = protocol.GenerateDestinationConnectionID() + if err != nil { + return err + } + } + c.srcConnID = srcConnID + c.destConnID = destConnID + return nil +} + func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) @@ -506,15 +512,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { // switch to negotiated version c.initialVersion = c.version c.version = newVersion - var err error - c.destConnID, err = generateConnectionID() - if err != nil { - return err - } - // in gQUIC, there's only one connection ID - if !c.version.UsesTLS() { - c.srcConnID = c.destConnID - } + c.generateConnectionIDs() + c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.session.Close(errCloseSessionForNewVersion) return nil diff --git a/client_test.go b/client_test.go index 30d2bb5c1..d0856ab69 100644 --- a/client_test.go +++ b/client_test.go @@ -81,11 +81,11 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - var origGenerateConnectionID func() (protocol.ConnectionID, error) + var origGenerateConnectionID func(int) (protocol.ConnectionID, error) BeforeEach(func() { origGenerateConnectionID = generateConnectionID - generateConnectionID = func() (protocol.ConnectionID, error) { + generateConnectionID = func(int) (protocol.ConnectionID, error) { return connID, nil } }) diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index cf2f480df..beacbfcfd 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -10,15 +10,28 @@ import ( // A ConnectionID in QUIC type ConnectionID []byte +const maxConnectionIDLen = 18 + // GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID() (ConnectionID, error) { - b := make([]byte, ConnectionIDLenGQUIC) +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) if _, err := rand.Read(b); err != nil { return nil, err } return ConnectionID(b), nil } +// GenerateDestinationConnectionID generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 18 bytes. +func GenerateDestinationConnectionID() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index 9f7d17de7..3d0d90e2c 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -10,14 +10,38 @@ import ( var _ = Describe("Connection ID generation", func() { It("generates random connection IDs", func() { - c1, err := GenerateConnectionID() + c1, err := GenerateConnectionID(8) Expect(err).ToNot(HaveOccurred()) Expect(c1).ToNot(BeZero()) - c2, err := GenerateConnectionID() + c2, err := GenerateConnectionID(8) Expect(err).ToNot(HaveOccurred()) Expect(c1).ToNot(Equal(c2)) }) + It("generates connection IDs with the requested length", func() { + c, err := GenerateConnectionID(5) + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(Equal(5)) + }) + + It("generates random length destination connection IDs", func() { + var has8ByteConnID, has18ByteConnID bool + for i := 0; i < 1000; i++ { + c, err := GenerateDestinationConnectionID() + Expect(err).ToNot(HaveOccurred()) + Expect(c.Len()).To(BeNumerically(">=", 8)) + Expect(c.Len()).To(BeNumerically("<=", 18)) + if c.Len() == 8 { + has8ByteConnID = true + } + if c.Len() == 18 { + has18ByteConnID = true + } + } + Expect(has8ByteConnID).To(BeTrue()) + Expect(has18ByteConnID).To(BeTrue()) + }) + It("says if connection IDs are equal", func() { c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} diff --git a/server_tls.go b/server_tls.go index b88372429..5d303b0c8 100644 --- a/server_tls.go +++ b/server_tls.go @@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, StreamID: version.CryptoStreamID(), Data: bc.GetDataForWriting(), } + srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) + if err != nil { + return nil, nil, err + } replyHdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, + SrcConnectionID: srcConnID, PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()), PacketNumber: hdr.PacketNumber, // echo the client's packet number PacketNumberLen: hdr.PacketNumberLen, @@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) } params := <-paramsChan - connID, err := protocol.GenerateConnectionID() + connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC) if err != nil { return nil, nil, err } diff --git a/server_tls_test.go b/server_tls_test.go index b0e5af1a6..e7c804d1e 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -71,7 +71,7 @@ var _ = Describe("Stateless TLS handling", func() { return hdr, data } - unpackPacket := func(data []byte) (*wire.Header, []byte) { + unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) { r := bytes.NewReader(conn.dataWritten.Bytes()) iHdr, err := wire.ParseInvariantHeader(r, 0) Expect(err).ToNot(HaveOccurred()) @@ -80,7 +80,7 @@ var _ = Describe("Stateless TLS handling", func() { hdr.Raw = data[:len(data)-r.Len()] var payload []byte if r.Len() > 0 { - aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred()) payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw) Expect(err).ToNot(HaveOccurred()) @@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() { } server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr, _ := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.IsVersionNegotiation).To(BeTrue()) Expect(sessionChan).ToNot(Receive()) }) @@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() { hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) server.HandleInitial(nil, hdr, data) Expect(conn.dataWritten.Len()).ToNot(BeZero()) - replyHdr, payload := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */)) Expect(sessionChan).ToNot(Receive()) @@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() { // the Handshake packet is written by the session Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) // unpack the packet to check that it actually contains a CONNECTION_CLOSE - replyHdr, data := unpackPacket(conn.dataWritten.Bytes()) + replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))