forked from quic-go/quic-go
use a random length destination connection ID on the Initial packet
The destination connection ID on the Initial packet must be at least 8 bytes long. By using all valid values, we make sure that the everything works correctly. The server chooses a new connection ID with the Retry or Handshake packet it sends, so the overhead of this is negligible.
This commit is contained in:
47
client.go
47
client.go
@@ -140,23 +140,13 @@ func newClient(
|
||||
) (*client, error) {
|
||||
clientConfig := populateClientConfig(config)
|
||||
version := clientConfig.Versions[0]
|
||||
srcConnID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if version.UsesTLS() {
|
||||
destConnID, err = generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
var err error
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -175,10 +165,8 @@ func newClient(
|
||||
if closeCallback != nil {
|
||||
onClose = closeCallback
|
||||
}
|
||||
return &client{
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
@@ -186,7 +174,8 @@ func newClient(
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}, nil
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
@@ -243,6 +232,23 @@ func populateClientConfig(config *Config) *Config {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if c.version.UsesTLS() {
|
||||
destConnID, err = protocol.GenerateDestinationConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
@@ -506,15 +512,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
var err error
|
||||
c.destConnID, err = generateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// in gQUIC, there's only one connection ID
|
||||
if !c.version.UsesTLS() {
|
||||
c.srcConnID = c.destConnID
|
||||
}
|
||||
c.generateConnectionIDs()
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
|
||||
@@ -81,11 +81,11 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
Context("Dialing", func() {
|
||||
var origGenerateConnectionID func() (protocol.ConnectionID, error)
|
||||
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
|
||||
|
||||
BeforeEach(func() {
|
||||
origGenerateConnectionID = generateConnectionID
|
||||
generateConnectionID = func() (protocol.ConnectionID, error) {
|
||||
generateConnectionID = func(int) (protocol.ConnectionID, error) {
|
||||
return connID, nil
|
||||
}
|
||||
})
|
||||
|
||||
@@ -10,15 +10,28 @@ import (
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID []byte
|
||||
|
||||
const maxConnectionIDLen = 18
|
||||
|
||||
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
func GenerateConnectionID() (ConnectionID, error) {
|
||||
b := make([]byte, ConnectionIDLenGQUIC)
|
||||
func GenerateConnectionID(len int) (ConnectionID, error) {
|
||||
b := make([]byte, len)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ConnectionID(b), nil
|
||||
}
|
||||
|
||||
// GenerateDestinationConnectionID generates a connection ID for the Initial packet.
|
||||
// It uses a length randomly chosen between 8 and 18 bytes.
|
||||
func GenerateDestinationConnectionID() (ConnectionID, error) {
|
||||
r := make([]byte, 1)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(len)
|
||||
}
|
||||
|
||||
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
|
||||
// It returns io.EOF if there are not enough bytes to read.
|
||||
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
|
||||
|
||||
@@ -10,14 +10,38 @@ import (
|
||||
|
||||
var _ = Describe("Connection ID generation", func() {
|
||||
It("generates random connection IDs", func() {
|
||||
c1, err := GenerateConnectionID()
|
||||
c1, err := GenerateConnectionID(8)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c1).ToNot(BeZero())
|
||||
c2, err := GenerateConnectionID()
|
||||
c2, err := GenerateConnectionID(8)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c1).ToNot(Equal(c2))
|
||||
})
|
||||
|
||||
It("generates connection IDs with the requested length", func() {
|
||||
c, err := GenerateConnectionID(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Len()).To(Equal(5))
|
||||
})
|
||||
|
||||
It("generates random length destination connection IDs", func() {
|
||||
var has8ByteConnID, has18ByteConnID bool
|
||||
for i := 0; i < 1000; i++ {
|
||||
c, err := GenerateDestinationConnectionID()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c.Len()).To(BeNumerically(">=", 8))
|
||||
Expect(c.Len()).To(BeNumerically("<=", 18))
|
||||
if c.Len() == 8 {
|
||||
has8ByteConnID = true
|
||||
}
|
||||
if c.Len() == 18 {
|
||||
has18ByteConnID = true
|
||||
}
|
||||
}
|
||||
Expect(has8ByteConnID).To(BeTrue())
|
||||
Expect(has18ByteConnID).To(BeTrue())
|
||||
})
|
||||
|
||||
It("says if connection IDs are equal", func() {
|
||||
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
@@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
||||
StreamID: version.CryptoStreamID(),
|
||||
Data: bc.GetDataForWriting(),
|
||||
}
|
||||
srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
replyHdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
DestConnectionID: hdr.SrcConnectionID,
|
||||
SrcConnectionID: hdr.DestConnectionID,
|
||||
SrcConnectionID: srcConnID,
|
||||
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
|
||||
PacketNumber: hdr.PacketNumber, // echo the client's packet number
|
||||
PacketNumberLen: hdr.PacketNumberLen,
|
||||
@@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
||||
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
|
||||
}
|
||||
params := <-paramsChan
|
||||
connID, err := protocol.GenerateConnectionID()
|
||||
connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||
return hdr, data
|
||||
}
|
||||
|
||||
unpackPacket := func(data []byte) (*wire.Header, []byte) {
|
||||
unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) {
|
||||
r := bytes.NewReader(conn.dataWritten.Bytes())
|
||||
iHdr, err := wire.ParseInvariantHeader(r, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -80,7 +80,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||
hdr.Raw = data[:len(data)-r.Len()]
|
||||
var payload []byte
|
||||
if r.Len() > 0 {
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS)
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||
}
|
||||
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
|
||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes())
|
||||
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
|
||||
Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
|
||||
Expect(sessionChan).ToNot(Receive())
|
||||
})
|
||||
@@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||
server.HandleInitial(nil, hdr, data)
|
||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes())
|
||||
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */))
|
||||
Expect(sessionChan).ToNot(Receive())
|
||||
@@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||
// the Handshake packet is written by the session
|
||||
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
|
||||
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
|
||||
replyHdr, data := unpackPacket(conn.dataWritten.Bytes())
|
||||
replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
|
||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake))
|
||||
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||
|
||||
Reference in New Issue
Block a user