Merge pull request #1427 from lucas-clemente/connection-id-len

make the connection ID length configurable
This commit is contained in:
Marten Seemann
2018-07-03 18:58:08 +07:00
committed by GitHub
22 changed files with 298 additions and 176 deletions

View File

@@ -1,5 +1,9 @@
# Changelog
## v0.9.0 (unreleased)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).

View File

@@ -74,6 +74,7 @@ func DialAddrContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, false)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@@ -115,8 +116,12 @@ func DialContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, true)
multiplexer := getClientMultiplexer()
manager := multiplexer.AddConn(pconn)
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
if err != nil {
return nil, err
@@ -138,25 +143,12 @@ func newClient(
host string,
closeCallback func(protocol.ConnectionID),
) (*client, error) {
clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0]
srcConnID, err := generateConnectionID()
if err != nil {
return nil, err
}
destConnID := srcConnID
if version.UsesTLS() {
destConnID, err = generateConnectionID()
if err != nil {
return nil, err
}
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
@@ -175,23 +167,22 @@ func newClient(
if closeCallback != nil {
onClose = closeCallback
}
return &client{
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: version,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}, nil
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
func populateClientConfig(config *Config, onPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
@@ -229,12 +220,17 @@ func populateClientConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && onPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
@@ -243,6 +239,27 @@ func populateClientConfig(config *Config) *Config {
}
}
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = protocol.GenerateDestinationConnectionID()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@@ -364,7 +381,7 @@ func (c *client) handleRead(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength)
// drop the packet if we can't parse the header
if err != nil {
c.logger.Errorf("error parsing invariant header: %s", err)
@@ -506,15 +523,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
var err error
c.destConnID, err = generateConnectionID()
if err != nil {
return err
}
// in gQUIC, there's only one connection ID
if !c.version.UsesTLS() {
c.srcConnID = c.destConnID
}
c.generateConnectionIDs()
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.Close(errCloseSessionForNewVersion)
return nil

View File

@@ -3,6 +3,7 @@ package quic
import (
"bytes"
"errors"
"fmt"
"net"
"strings"
"sync"
@@ -19,16 +20,21 @@ var (
)
type multiplexer interface {
AddConn(net.PacketConn) packetHandlerManager
AddConn(net.PacketConn, int) (packetHandlerManager, error)
AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error
}
type connManager struct {
connIDLen int
manager packetHandlerManager
}
// The clientMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the session handler.
type clientMultiplexer struct {
mutex sync.Mutex
conns map[net.PacketConn]packetHandlerManager
conns map[net.PacketConn]connManager
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
logger utils.Logger
@@ -39,7 +45,7 @@ var _ multiplexer = &clientMultiplexer{}
func getClientMultiplexer() multiplexer {
clientMuxerOnce.Do(func() {
clientMuxer = &clientMultiplexer{
conns: make(map[net.PacketConn]packetHandlerManager),
conns: make(map[net.PacketConn]connManager),
logger: utils.DefaultLogger.WithPrefix("client muxer"),
newPacketHandlerManager: newPacketHandlerMap,
}
@@ -47,30 +53,34 @@ func getClientMultiplexer() multiplexer {
return clientMuxer
}
func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager {
func (m *clientMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
sessions, ok := m.conns[c]
p, ok := m.conns[c]
if !ok {
sessions = m.newPacketHandlerManager()
m.conns[c] = sessions
manager := m.newPacketHandlerManager()
p = connManager{connIDLen: connIDLen, manager: manager}
m.conns[c] = p
// If we didn't know this packet conn before, listen for incoming packets
// and dispatch them to the right sessions.
go m.listen(c, sessions)
go m.listen(c, p)
}
return sessions
if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
}
return p.manager, nil
}
func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
sessions, ok := m.conns[c]
p, ok := m.conns[c]
if !ok {
return errors.New("unknown packet conn %s")
}
sessions.Add(connID, handler)
p.manager.Add(connID, handler)
return nil
}
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
@@ -79,7 +89,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
n, addr, err := c.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
sessions.Close(err)
p.manager.Close(err)
}
return
}
@@ -87,13 +97,13 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
rcvTime := time.Now()
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
m.logger.Debugf("error parsing invariant header from %s: %s", addr, err)
continue
}
client, ok := sessions.Get(iHdr.DestConnectionID)
client, ok := p.manager.Get(iHdr.DestConnectionID)
if !ok {
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
continue

View File

@@ -29,11 +29,12 @@ var _ = Describe("Client Multiplexer", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockQuicSession(mockCtrl)
handledPacket := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) {
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID))
close(handledPacket)
})
packetHandler.EXPECT().GetVersion()
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, 8)
err := getClientMultiplexer().AddHandler(conn, connID, packetHandler)
Expect(err).ToNot(HaveOccurred())
conn.dataToRead <- getPacket(connID)
@@ -43,6 +44,14 @@ var _ = Describe("Client Multiplexer", func() {
close(conn.dataToRead)
})
It("errors when adding an existing conn with a different connection ID length", func() {
conn := newMockPacketConn()
_, err := getClientMultiplexer().AddConn(conn, 5)
Expect(err).ToNot(HaveOccurred())
_, err = getClientMultiplexer().AddConn(conn, 6)
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
})
It("errors when adding a handler for an unknown conn", func() {
conn := newMockPacketConn()
err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl))
@@ -67,7 +76,7 @@ var _ = Describe("Client Multiplexer", func() {
close(handledPacket2)
})
packetHandler2.EXPECT().GetVersion()
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, connID1.Len())
Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed())
Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed())
@@ -84,10 +93,10 @@ var _ = Describe("Client Multiplexer", func() {
It("drops unparseable packets", func() {
conn := newMockPacketConn()
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
conn.dataToRead <- []byte("invalid header")
packetHandler := NewMockQuicSession(mockCtrl)
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, 7)
Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed())
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
@@ -106,7 +115,7 @@ var _ = Describe("Client Multiplexer", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
done := make(chan struct{})
manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true)
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, 8)
conn.dataToRead <- getPacket(connID)
Eventually(done).Should(BeClosed())
// makes the listen go routine return
@@ -118,7 +127,7 @@ var _ = Describe("Client Multiplexer", func() {
conn := newMockPacketConn()
conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
packetHandler := NewMockQuicSession(mockCtrl)
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, 8)
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed())
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
// makes the listen go routine return
@@ -135,7 +144,7 @@ var _ = Describe("Client Multiplexer", func() {
packetHandler.EXPECT().Close(testErr).Do(func(error) {
close(done)
})
getClientMultiplexer().AddConn(conn)
getClientMultiplexer().AddConn(conn, 8)
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed())
Eventually(done).Should(BeClosed())
})

View File

@@ -81,11 +81,11 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
var origGenerateConnectionID func() (protocol.ConnectionID, error)
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
generateConnectionID = func() (protocol.ConnectionID, error) {
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
})
@@ -147,7 +147,7 @@ var _ = Describe("Client", func() {
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
run := make(chan struct{})
@@ -176,7 +176,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the connection to become secure", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("early handshake error")
@@ -203,7 +203,7 @@ var _ = Describe("Client", func() {
It("closes the session when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
sessionRunning := make(chan struct{})
@@ -243,7 +243,7 @@ var _ = Describe("Client", func() {
It("removes closed sessions from the multiplexer", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Remove(connID)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
var runner sessionRunner
@@ -279,18 +279,20 @@ var _ = Describe("Client", func() {
RequestConnectionIDOmission: true,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 13,
}
c := populateClientConfig(config)
c := populateClientConfig(config, false)
Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute))
Expect(c.IdleTimeout).To(Equal(42 * time.Hour))
Expect(c.RequestConnectionIDOmission).To(BeTrue())
Expect(c.MaxIncomingStreams).To(Equal(1234))
Expect(c.MaxIncomingUniStreams).To(Equal(4321))
Expect(c.ConnectionIDLength).To(Equal(13))
})
It("errors when the Config contains an invalid version", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234)
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
@@ -302,7 +304,7 @@ var _ = Describe("Client", func() {
MaxIncomingStreams: -1,
MaxIncomingUniStreams: 4321,
}
c := populateClientConfig(config)
c := populateClientConfig(config, false)
Expect(c.MaxIncomingStreams).To(BeZero())
Expect(c.MaxIncomingUniStreams).To(Equal(4321))
})
@@ -312,13 +314,25 @@ var _ = Describe("Client", func() {
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: -1,
}
c := populateClientConfig(config)
c := populateClientConfig(config, false)
Expect(c.MaxIncomingStreams).To(Equal(1234))
Expect(c.MaxIncomingUniStreams).To(BeZero())
})
It("uses 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, false)
Expect(c.ConnectionIDLength).To(BeZero())
})
It("doesn't use 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, true)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
It("fills in default values if options are not set in the Config", func() {
c := populateClientConfig(&Config{})
c := populateClientConfig(&Config{}, false)
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
Expect(c.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
@@ -329,7 +343,7 @@ var _ = Describe("Client", func() {
Context("gQUIC", func() {
It("errors if it can't create a session", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("error creating session")
@@ -355,7 +369,7 @@ var _ = Describe("Client", func() {
Context("IETF QUIC", func() {
It("creates new TLS sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
@@ -411,7 +425,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs during version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
testErr := errors.New("early handshake error")
@@ -568,6 +582,7 @@ var _ = Describe("Client", func() {
})
It("drops version negotiation packets that contain the offered version", func() {
cl.config = &Config{}
ver := cl.version
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
Expect(cl.version).To(Equal(ver))
@@ -581,6 +596,7 @@ var _ = Describe("Client", func() {
})
It("ignores packets with an invalid public header", func() {
cl.config = &Config{}
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
cl.handleRead(addr, []byte("invalid packet"))
})
@@ -682,7 +698,7 @@ var _ = Describe("Client", func() {
It("creates new gQUIC sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: protocol.SupportedVersions}
@@ -723,7 +739,7 @@ var _ = Describe("Client", func() {
It("creates a new session when the server performs a retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
@@ -757,7 +773,7 @@ var _ = Describe("Client", func() {
It("only accepts one Retry packet", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any())
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
@@ -866,7 +882,7 @@ var _ = Describe("Client", func() {
pr = wire.WritePublicReset(cl.destConnID, 1, 0)
r := bytes.NewReader(pr)
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err = iHdr.Parse(r, protocol.PerspectiveServer, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())

View File

@@ -165,6 +165,13 @@ type Config struct {
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.

View File

@@ -10,15 +10,28 @@ import (
// A ConnectionID in QUIC
type ConnectionID []byte
const maxConnectionIDLen = 18
// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID() (ConnectionID, error) {
b := make([]byte, ConnectionIDLen)
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return ConnectionID(b), nil
}
// GenerateDestinationConnectionID generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 18 bytes.
func GenerateDestinationConnectionID() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {

View File

@@ -10,14 +10,38 @@ import (
var _ = Describe("Connection ID generation", func() {
It("generates random connection IDs", func() {
c1, err := GenerateConnectionID()
c1, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(BeZero())
c2, err := GenerateConnectionID()
c2, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(Equal(c2))
})
It("generates connection IDs with the requested length", func() {
c, err := GenerateConnectionID(5)
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(Equal(5))
})
It("generates random length destination connection IDs", func() {
var has8ByteConnID, has18ByteConnID bool
for i := 0; i < 1000; i++ {
c, err := GenerateDestinationConnectionID()
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(BeNumerically(">=", 8))
Expect(c.Len()).To(BeNumerically("<=", 18))
if c.Len() == 8 {
has8ByteConnID = true
}
if c.Len() == 18 {
has18ByteConnID = true
}
}
Expect(has8ByteConnID).To(BeTrue())
Expect(has18ByteConnID).To(BeTrue())
})
It("says if connection IDs are equal", func() {
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}

View File

@@ -82,3 +82,9 @@ const MinInitialPacketSize = 1200
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
const ConnectionIDLenGQUIC = 8
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@@ -146,11 +146,6 @@ const MaxAckFrameSize ByteCount = 1000
// Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth.
const MinPacingDelay time.Duration = 100 * time.Microsecond
// ConnectionIDLen is the length of the source Connection ID used on IETF QUIC packets.
// The Short Header contains the connection ID, but not the length,
// so we need to know this value in advance (or encode it into the connection ID).
// TODO: make this configurable
const ConnectionIDLen = 8
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
// if no other value is configured.
const DefaultConnectionIDLength = 4

View File

@@ -56,9 +56,6 @@ func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version proto
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer) error {
if h.SrcConnectionID.Len() != protocol.ConnectionIDLen {
return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len())
}
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
@@ -174,7 +171,7 @@ func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
return 0, errPacketNumberLenNotSet
}
length += protocol.ByteCount(h.PacketNumberLen)
length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes
length += protocol.ByteCount(h.DestConnectionID.Len())
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4

View File

@@ -21,7 +21,7 @@ type InvariantHeader struct {
}
// ParseInvariantHeader parses the version independent part of the header
func ParseInvariantHeader(b *bytes.Reader) (*InvariantHeader, error) {
func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
@@ -36,8 +36,15 @@ func ParseInvariantHeader(b *bytes.Reader) (*InvariantHeader, error) {
// In the IETF Short Header:
// * 0x8 it is the gQUIC Demultiplexing bit, and always 0.
// * 0x20 and 0x10 are always 1.
if typeByte&0x8 > 0 || typeByte&0x38 == 0x30 {
h.DestConnectionID, err = protocol.ReadConnectionID(b, 8)
var connIDLen int
if typeByte&0x8 > 0 { // Public Header containing a connection ID
connIDLen = 8
}
if typeByte&0x38 == 0x30 { // Short Header
connIDLen = shortHeaderConnIDLen
}
if connIDLen > 0 {
h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen)
if err != nil {
return nil, err
}

View File

@@ -22,13 +22,13 @@ var _ = Describe("Header Parsing", func() {
Context("Version Negotiation Packets", func() {
It("parses", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1}
versions := []protocol.VersionNumber{0x22334455, 0x33445566}
data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.DestConnectionID).To(Equal(destConnID))
Expect(iHdr.SrcConnectionID).To(Equal(srcConnID))
@@ -50,7 +50,7 @@ var _ = Describe("Header Parsing", func() {
data, err := ComposeVersionNegotiation(connID, connID, versions)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data[:len(data)-2])
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket))
@@ -63,7 +63,7 @@ var _ = Describe("Header Parsing", func() {
Expect(err).ToNot(HaveOccurred())
// remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number
b := bytes.NewReader(data[:len(data)-8])
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list"))
@@ -71,13 +71,13 @@ var _ = Describe("Header Parsing", func() {
})
Context("Long Headers", func() {
It("parses a long header", func() {
destConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}
srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x42, 0x42}
It("parses a Long Header", func() {
destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1}
srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
0x55, // connection ID lengths
0x61, // connection ID lengths
}
data = append(data, destConnID...)
data = append(data, srcConnID...)
@@ -86,7 +86,7 @@ var _ = Describe("Header Parsing", func() {
data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeTrue())
Expect(iHdr.DestConnectionID).To(Equal(destConnID))
@@ -105,7 +105,7 @@ var _ = Describe("Header Parsing", func() {
Expect(b.Len()).To(BeZero())
})
It("parses a long header without a destination connection ID", func() {
It("parses a Long Header without a destination connection ID", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
@@ -115,13 +115,13 @@ var _ = Describe("Header Parsing", func() {
data = append(data, encodeVarInt(0x42)...) // payload length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}))
Expect(iHdr.DestConnectionID).To(BeEmpty())
})
It("parses a long header without a source connection ID", func() {
It("parses a Long Header without a source connection ID", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
@@ -131,13 +131,13 @@ var _ = Describe("Header Parsing", func() {
data = append(data, encodeVarInt(0x42)...) // payload length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.SrcConnectionID).To(BeEmpty())
Expect(iHdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
})
It("parses a long header with a 2 byte packet number", func() {
It("parses a Long Header with a 2 byte packet number", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
@@ -146,7 +146,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, encodeVarInt(0x42)...) // payload length
data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())
@@ -167,7 +167,7 @@ var _ = Describe("Header Parsing", func() {
}).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(buf.Bytes())
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader)
Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42"))
@@ -182,7 +182,7 @@ var _ = Describe("Header Parsing", func() {
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID
}
for i := 0; i < len(data); i++ {
_, err := ParseInvariantHeader(bytes.NewReader(data[:i]))
_, err := ParseInvariantHeader(bytes.NewReader(data[:i]), 0)
Expect(err).To(Equal(io.EOF))
}
})
@@ -198,7 +198,7 @@ var _ = Describe("Header Parsing", func() {
data = appendPacketNumber(data, 0xdeadbeef, protocol.PacketNumberLen4)
for i := iHdrLen; i < len(data); i++ {
b := bytes.NewReader(data[:i])
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).To(Equal(io.EOF))
@@ -207,12 +207,12 @@ var _ = Describe("Header Parsing", func() {
})
Context("Short Headers", func() {
It("reads a short header with a connection ID", func() {
It("reads a Short Header with a 8 byte connection ID", func() {
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}
data := append([]byte{0x30}, connID...)
data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 8)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
Expect(iHdr.DestConnectionID).To(Equal(connID))
@@ -226,14 +226,31 @@ var _ = Describe("Header Parsing", func() {
Expect(b.Len()).To(BeZero())
})
It("reads a Short Header with a 5 byte connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5}
data := append([]byte{0x30}, connID...)
data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b, 5)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
Expect(iHdr.DestConnectionID).To(Equal(connID))
hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.KeyPhase).To(Equal(0))
Expect(hdr.DestConnectionID).To(Equal(connID))
Expect(hdr.SrcConnectionID).To(BeEmpty())
Expect(b.Len()).To(BeZero())
})
It("reads the Key Phase Bit", func() {
data := []byte{
0x30 ^ 0x40,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID
}
data = appendPacketNumber(data, 11, protocol.PacketNumberLen1)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 6)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())
@@ -245,11 +262,11 @@ var _ = Describe("Header Parsing", func() {
It("reads a header with a 2 byte packet number", func() {
data := []byte{
0x30 ^ 0x40 ^ 0x1,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, // connection ID
}
data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen2)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())
@@ -262,11 +279,11 @@ var _ = Describe("Header Parsing", func() {
It("reads a header with a 4 byte packet number", func() {
data := []byte{
0x30 ^ 0x40 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID
}
data = appendPacketNumber(data, 0x99beef, protocol.PacketNumberLen4)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 10)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())
@@ -282,7 +299,7 @@ var _ = Describe("Header Parsing", func() {
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
}
for i := 0; i < len(data); i++ {
_, err := ParseInvariantHeader(bytes.NewReader(data[:i]))
_, err := ParseInvariantHeader(bytes.NewReader(data[:i]), 8)
Expect(err).To(Equal(io.EOF))
}
})
@@ -290,13 +307,13 @@ var _ = Describe("Header Parsing", func() {
It("errors on EOF, when parsing the invariant header", func() {
data := []byte{
0x30 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID
}
iHdrLen := len(data)
data = appendPacketNumber(data, 0xdeadbeef, protocol.PacketNumberLen4)
for i := iHdrLen; i < len(data); i++ {
b := bytes.NewReader(data[:i])
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 6)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveClient, versionIETFHeader)
Expect(err).To(Equal(io.EOF))
@@ -307,10 +324,14 @@ var _ = Describe("Header Parsing", func() {
Context("Public Header", func() {
It("accepts a sample client header", func() {
ver := make([]byte, 4)
binary.BigEndian.PutUint32(ver, uint32(protocol.SupportedVersions[0]))
b := bytes.NewReader(append(append([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, ver...), 0x01))
iHdr, err := ParseInvariantHeader(b)
data := []byte{
0x9,
0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6,
}
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...)
data = append(data, 0x1) // packet number
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
hdr, err := iHdr.Parse(b, protocol.PerspectiveClient, versionPublicHeader)
@@ -321,7 +342,7 @@ var _ = Describe("Header Parsing", func() {
connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
Expect(hdr.DestConnectionID).To(Equal(connID))
Expect(hdr.SrcConnectionID).To(BeEmpty())
Expect(hdr.Version).To(Equal(protocol.SupportedVersions[0]))
Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef)))
Expect(hdr.SupportedVersions).To(BeEmpty())
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(1)))
Expect(b.Len()).To(BeZero())
@@ -329,7 +350,7 @@ var _ = Describe("Header Parsing", func() {
It("accepts an omitted connection ID", func() {
b := bytes.NewReader([]byte{0x0, 0x1})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 8)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
Expect(iHdr.DestConnectionID).To(BeEmpty())
@@ -342,7 +363,7 @@ var _ = Describe("Header Parsing", func() {
It("parses a PUBLIC_RESET packet", func() {
b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
@@ -359,7 +380,7 @@ var _ = Describe("Header Parsing", func() {
divNonce := []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}
Expect(divNonce).To(HaveLen(32))
b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37))
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 7)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
@@ -380,7 +401,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{0x13, 37}...) // packet number
for i := iHdrLen; i < len(data); i++ {
b := bytes.NewReader(data[:i])
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 5)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).To(Equal(io.EOF))
@@ -398,7 +419,7 @@ var _ = Describe("Header Parsing", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
versions := []protocol.VersionNumber{0x13, 0x37}
b := bytes.NewReader(ComposeGQUICVersionNegotiation(connID, versions))
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 6)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())
@@ -416,7 +437,7 @@ var _ = Describe("Header Parsing", func() {
It("errors if it doesn't contain any versions", func() {
b := bytes.NewReader([]byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list"))
@@ -428,7 +449,7 @@ var _ = Describe("Header Parsing", func() {
data = appendVersion(data, protocol.SupportedVersions[0])
data = appendVersion(data, 99) // unsupported version
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())
@@ -442,7 +463,7 @@ var _ = Describe("Header Parsing", func() {
data := ComposeGQUICVersionNegotiation(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, protocol.SupportedVersions)
data = append(data, []byte{0x13, 0x37}...)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket))
@@ -452,7 +473,7 @@ var _ = Describe("Header Parsing", func() {
Context("Packet Number lengths", func() {
It("accepts 1-byte packet numbers", func() {
b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())
@@ -463,7 +484,7 @@ var _ = Describe("Header Parsing", func() {
It("accepts 2-byte packet numbers", func() {
b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())
@@ -474,7 +495,7 @@ var _ = Describe("Header Parsing", func() {
It("accepts 4-byte packet numbers", func() {
b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde})
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())

View File

@@ -84,21 +84,6 @@ var _ = Describe("Header", func() {
Expect(err).To(MatchError("invalid connection ID length: 19 bytes"))
})
It("refuses to write a Long Header with the wrong connection ID length", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
Expect(srcConnID).ToNot(Equal(protocol.ConnectionIDLen))
err := (&Header{
IsLongHeader: true,
Type: 0x5,
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, // connection IDs must be at most 18 bytes long
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
Version: 0x1020304,
}).Write(buf, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).To(MatchError("Header: source connection ID must be 8 bytes, is 6"))
})
It("writes a header with an 18 byte connection ID", func() {
err := (&Header{
IsLongHeader: true,
@@ -537,7 +522,7 @@ var _ = Describe("Header", func() {
data, err := ComposeVersionNegotiation(destConnID, srcConnID, []protocol.VersionNumber{0x12345678, 0x87654321})
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())

View File

@@ -14,7 +14,7 @@ var _ = Describe("Version Negotiation Packets", func() {
versions := []protocol.VersionNumber{1001, 1003}
data := ComposeGQUICVersionNegotiation(connID, versions)
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionPublicHeader)
Expect(err).ToNot(HaveOccurred())
@@ -32,7 +32,7 @@ var _ = Describe("Version Negotiation Packets", func() {
Expect(err).ToNot(HaveOccurred())
Expect(data[0] & 0x80).ToNot(BeZero())
b := bytes.NewReader(data)
iHdr, err := ParseInvariantHeader(b)
iHdr, err := ParseInvariantHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred())

View File

@@ -36,15 +36,16 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
}
// AddConn mocks base method
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn) packetHandlerManager {
ret := m.ctrl.Call(m, "AddConn", arg0)
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerManager, error) {
ret := m.ctrl.Call(m, "AddConn", arg0, arg1)
ret0, _ := ret[0].(packetHandlerManager)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddConn indicates an expected call of AddConn
func (mr *MockMultiplexerMockRecorder) AddConn(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0)
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1)
}
// AddHandler mocks base method

View File

@@ -65,7 +65,7 @@ var _ = Describe("Packet packer", func() {
checkPayloadLen := func(data []byte) {
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())

View File

@@ -241,6 +241,10 @@ func populateServerConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 {
connIDLen = protocol.DefaultConnectionIDLength
}
return &Config{
Versions: versions,
@@ -252,6 +256,7 @@ func populateServerConfig(config *Config) *Config {
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: connIDLen,
}
}
@@ -304,7 +309,7 @@ func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(packet)
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}

View File

@@ -48,6 +48,7 @@ var _ = Describe("Server", func() {
RequestConnectionIDOmission: true,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 12,
}
c := populateServerConfig(config)
Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute))
@@ -55,6 +56,7 @@ var _ = Describe("Server", func() {
Expect(c.RequestConnectionIDOmission).To(BeFalse())
Expect(c.MaxIncomingStreams).To(Equal(1234))
Expect(c.MaxIncomingUniStreams).To(Equal(4321))
Expect(c.ConnectionIDLength).To(Equal(12))
})
It("disables bidirectional streams", func() {
@@ -76,6 +78,12 @@ var _ = Describe("Server", func() {
Expect(c.MaxIncomingStreams).To(Equal(1234))
Expect(c.MaxIncomingUniStreams).To(BeZero())
})
It("doesn't use 0-byte connection IDs", func() {
config := &Config{}
c := populateClientConfig(config, true)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
})
Context("with mock session", func() {
@@ -500,7 +508,7 @@ var _ = Describe("Server", func() {
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
r := bytes.NewReader(conn.dataWritten.Bytes())
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
Expect(iHdr.IsLongHeader).To(BeFalse())
replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
@@ -546,7 +554,7 @@ var _ = Describe("Server", func() {
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
r := bytes.NewReader(conn.dataWritten.Bytes())
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())

View File

@@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(),
}
srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil {
return nil, nil, err
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
DestConnectionID: hdr.SrcConnectionID,
SrcConnectionID: hdr.DestConnectionID,
SrcConnectionID: srcConnID,
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
PacketNumber: hdr.PacketNumber, // echo the client's packet number
PacketNumberLen: hdr.PacketNumberLen,
@@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
connID, err := protocol.GenerateConnectionID()
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil {
return nil, nil, err
}

View File

@@ -71,16 +71,16 @@ var _ = Describe("Stateless TLS handling", func() {
return hdr, data
}
unpackPacket := func(data []byte) (*wire.Header, []byte) {
unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) {
r := bytes.NewReader(conn.dataWritten.Bytes())
iHdr, err := wire.ParseInvariantHeader(r)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = data[:len(data)-r.Len()]
var payload []byte
if r.Len() > 0 {
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS)
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
Expect(err).ToNot(HaveOccurred())
@@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() {
}
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
Expect(sessionChan).ToNot(Receive())
})
@@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() {
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */))
Expect(sessionChan).ToNot(Receive())
@@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() {
// the Handshake packet is written by the session
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
replyHdr, data := unpackPacket(conn.dataWritten.Bytes())
replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))

View File

@@ -1769,7 +1769,7 @@ var _ = Describe("Client Session", func() {
protocol.Version39,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
nil,
populateClientConfig(&Config{}),
populateClientConfig(&Config{}, false),
protocol.VersionWhatever,
nil,
utils.DefaultLogger,
@@ -1823,7 +1823,7 @@ var _ = Describe("Client Session", func() {
sess.queueControlFrame(&wire.PingFrame{})
var packet []byte
Eventually(mconn.written).Should(Receive(&packet))
hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet))
hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet), 0)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}))
// make sure the go routine returns