forked from quic-go/quic-go
make the connection ID length configurable
This commit is contained in:
@@ -3,6 +3,7 @@ package quic
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,16 +20,21 @@ var (
|
||||
)
|
||||
|
||||
type multiplexer interface {
|
||||
AddConn(net.PacketConn) packetHandlerManager
|
||||
AddConn(net.PacketConn, int) (packetHandlerManager, error)
|
||||
AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error
|
||||
}
|
||||
|
||||
type connManager struct {
|
||||
connIDLen int
|
||||
manager packetHandlerManager
|
||||
}
|
||||
|
||||
// The clientMultiplexer listens on multiple net.PacketConns and dispatches
|
||||
// incoming packets to the session handler.
|
||||
type clientMultiplexer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[net.PacketConn]packetHandlerManager
|
||||
conns map[net.PacketConn]connManager
|
||||
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
|
||||
|
||||
logger utils.Logger
|
||||
@@ -39,7 +45,7 @@ var _ multiplexer = &clientMultiplexer{}
|
||||
func getClientMultiplexer() multiplexer {
|
||||
clientMuxerOnce.Do(func() {
|
||||
clientMuxer = &clientMultiplexer{
|
||||
conns: make(map[net.PacketConn]packetHandlerManager),
|
||||
conns: make(map[net.PacketConn]connManager),
|
||||
logger: utils.DefaultLogger.WithPrefix("client muxer"),
|
||||
newPacketHandlerManager: newPacketHandlerMap,
|
||||
}
|
||||
@@ -47,30 +53,34 @@ func getClientMultiplexer() multiplexer {
|
||||
return clientMuxer
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager {
|
||||
func (m *clientMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
sessions, ok := m.conns[c]
|
||||
p, ok := m.conns[c]
|
||||
if !ok {
|
||||
sessions = m.newPacketHandlerManager()
|
||||
m.conns[c] = sessions
|
||||
manager := m.newPacketHandlerManager()
|
||||
p = connManager{connIDLen: connIDLen, manager: manager}
|
||||
m.conns[c] = p
|
||||
// If we didn't know this packet conn before, listen for incoming packets
|
||||
// and dispatch them to the right sessions.
|
||||
go m.listen(c, sessions)
|
||||
go m.listen(c, p)
|
||||
}
|
||||
return sessions
|
||||
if p.connIDLen != connIDLen {
|
||||
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
|
||||
}
|
||||
return p.manager, nil
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
|
||||
sessions, ok := m.conns[c]
|
||||
p, ok := m.conns[c]
|
||||
if !ok {
|
||||
return errors.New("unknown packet conn %s")
|
||||
}
|
||||
sessions.Add(connID, handler)
|
||||
p.manager.Add(connID, handler)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
|
||||
func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) {
|
||||
for {
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
@@ -79,7 +89,7 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
|
||||
n, addr, err := c.ReadFrom(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
sessions.Close(err)
|
||||
p.manager.Close(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -87,13 +97,13 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, protocol.ConnectionIDLenGQUIC)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
m.logger.Debugf("error parsing invariant header from %s: %s", addr, err)
|
||||
continue
|
||||
}
|
||||
client, ok := sessions.Get(iHdr.DestConnectionID)
|
||||
client, ok := p.manager.Get(iHdr.DestConnectionID)
|
||||
if !ok {
|
||||
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user