use a random length destination connection ID on the Initial packet

The destination connection ID on the Initial packet must be at least 8
bytes long. By using all valid values, we make sure that the everything
works correctly. The server chooses a new connection ID with the Retry
or Handshake packet it sends, so the overhead of this is negligible.
This commit is contained in:
Marten Seemann
2018-07-01 13:48:47 +07:00
parent 0bd7e744ff
commit 73f7636537
6 changed files with 78 additions and 38 deletions

View File

@@ -140,23 +140,13 @@ func newClient(
) (*client, error) {
clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0]
srcConnID, err := generateConnectionID()
if err != nil {
return nil, err
}
destConnID := srcConnID
if version.UsesTLS() {
destConnID, err = generateConnectionID()
if err != nil {
return nil, err
}
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
@@ -175,10 +165,8 @@ func newClient(
if closeCallback != nil {
onClose = closeCallback
}
return &client{
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
@@ -186,7 +174,8 @@ func newClient(
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}, nil
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@@ -243,6 +232,23 @@ func populateClientConfig(config *Config) *Config {
}
}
func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = protocol.GenerateDestinationConnectionID()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@@ -506,15 +512,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
var err error
c.destConnID, err = generateConnectionID()
if err != nil {
return err
}
// in gQUIC, there's only one connection ID
if !c.version.UsesTLS() {
c.srcConnID = c.destConnID
}
c.generateConnectionIDs()
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.Close(errCloseSessionForNewVersion)
return nil

View File

@@ -81,11 +81,11 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
var origGenerateConnectionID func() (protocol.ConnectionID, error)
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
generateConnectionID = func() (protocol.ConnectionID, error) {
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
})

View File

@@ -10,15 +10,28 @@ import (
// A ConnectionID in QUIC
type ConnectionID []byte
const maxConnectionIDLen = 18
// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID() (ConnectionID, error) {
b := make([]byte, ConnectionIDLenGQUIC)
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return ConnectionID(b), nil
}
// GenerateDestinationConnectionID generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 18 bytes.
func GenerateDestinationConnectionID() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {

View File

@@ -10,14 +10,38 @@ import (
var _ = Describe("Connection ID generation", func() {
It("generates random connection IDs", func() {
c1, err := GenerateConnectionID()
c1, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(BeZero())
c2, err := GenerateConnectionID()
c2, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(Equal(c2))
})
It("generates connection IDs with the requested length", func() {
c, err := GenerateConnectionID(5)
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(Equal(5))
})
It("generates random length destination connection IDs", func() {
var has8ByteConnID, has18ByteConnID bool
for i := 0; i < 1000; i++ {
c, err := GenerateDestinationConnectionID()
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(BeNumerically(">=", 8))
Expect(c.Len()).To(BeNumerically("<=", 18))
if c.Len() == 8 {
has8ByteConnID = true
}
if c.Len() == 18 {
has18ByteConnID = true
}
}
Expect(has8ByteConnID).To(BeTrue())
Expect(has18ByteConnID).To(BeTrue())
})
It("says if connection IDs are equal", func() {
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}

View File

@@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(),
}
srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return nil, nil, err
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
DestConnectionID: hdr.SrcConnectionID,
SrcConnectionID: hdr.DestConnectionID,
SrcConnectionID: srcConnID,
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
PacketNumber: hdr.PacketNumber, // echo the client's packet number
PacketNumberLen: hdr.PacketNumberLen,
@@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
connID, err := protocol.GenerateConnectionID()
connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return nil, nil, err
}

View File

@@ -71,7 +71,7 @@ var _ = Describe("Stateless TLS handling", func() {
return hdr, data
}
unpackPacket := func(data []byte) (*wire.Header, []byte) {
unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) {
r := bytes.NewReader(conn.dataWritten.Bytes())
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
@@ -80,7 +80,7 @@ var _ = Describe("Stateless TLS handling", func() {
hdr.Raw = data[:len(data)-r.Len()]
var payload []byte
if r.Len() > 0 {
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS)
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
Expect(err).ToNot(HaveOccurred())
@@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() {
}
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
Expect(sessionChan).ToNot(Receive())
})
@@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() {
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */))
Expect(sessionChan).ToNot(Receive())
@@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() {
// the Handshake packet is written by the session
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
replyHdr, data := unpackPacket(conn.dataWritten.Bytes())
replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))