forked from quic-go/quic-go
also use the multiplexer for the server
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -24,6 +23,7 @@ type packetHandlerMap struct {
|
||||
connIDLen int
|
||||
|
||||
handlers map[string] /* string(ConnectionID)*/ packetHandler
|
||||
server unknownPacketHandler
|
||||
closed bool
|
||||
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
@@ -33,8 +33,7 @@ type packetHandlerMap struct {
|
||||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
// TODO(#561): remove the listen flag
|
||||
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
|
||||
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
|
||||
m := &packetHandlerMap{
|
||||
conn: conn,
|
||||
connIDLen: connIDLen,
|
||||
@@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
|
||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
if listen {
|
||||
go m.listen()
|
||||
}
|
||||
go m.listen()
|
||||
return m
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
|
||||
h.mutex.RLock()
|
||||
sess, ok := h.handlers[string(id)]
|
||||
h.mutex.RUnlock()
|
||||
return sess, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = handler
|
||||
@@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
h.removeByConnectionIDAsString(string(id))
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
|
||||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = nil
|
||||
h.handlers[id] = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
delete(h.handlers, string(id))
|
||||
delete(h.handlers, id)
|
||||
h.mutex.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Close() error {
|
||||
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
||||
h.mutex.Lock()
|
||||
h.server = s
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) CloseServer() {
|
||||
h.mutex.Lock()
|
||||
h.server = nil
|
||||
var wg sync.WaitGroup
|
||||
for id, handler := range h.handlers {
|
||||
if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
|
||||
wg.Add(1)
|
||||
go func(id string, handler packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = handler.Close()
|
||||
h.removeByConnectionIDAsString(id)
|
||||
wg.Done()
|
||||
}(id, handler)
|
||||
}
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) close(e error) error {
|
||||
h.mutex.Lock()
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
@@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error {
|
||||
if handler != nil {
|
||||
wg.Add(1)
|
||||
go func(handler packetHandler) {
|
||||
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||||
_ = handler.Close()
|
||||
handler.destroy(e)
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.closeWithError(e)
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
return nil
|
||||
@@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() {
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err := h.conn.ReadFrom(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
h.Close()
|
||||
}
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
data = data[:n]
|
||||
@@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing invariant header: %s", err)
|
||||
}
|
||||
handler, ok := h.Get(iHdr.DestConnectionID)
|
||||
if !ok {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
if handler == nil {
|
||||
|
||||
h.mutex.RLock()
|
||||
handler, ok := h.handlers[string(iHdr.DestConnectionID)]
|
||||
server := h.server
|
||||
h.mutex.RUnlock()
|
||||
|
||||
var sentBy protocol.Perspective
|
||||
var version protocol.VersionNumber
|
||||
var handlePacket func(*receivedPacket)
|
||||
if ok && handler == nil {
|
||||
// Late packet for closed session
|
||||
return nil
|
||||
}
|
||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
|
||||
if !ok {
|
||||
if server == nil { // no server set
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
handlePacket = server.handlePacket
|
||||
sentBy = protocol.PerspectiveClient
|
||||
version = iHdr.Version
|
||||
} else {
|
||||
sentBy = handler.GetPerspective().Opposite()
|
||||
version = handler.GetVersion()
|
||||
handlePacket = handler.handlePacket
|
||||
}
|
||||
|
||||
hdr, err := iHdr.Parse(r, sentBy, version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing header: %s", err)
|
||||
}
|
||||
@@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
}
|
||||
|
||||
handler.handlePacket(&receivedPacket{
|
||||
handlePacket(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
|
||||
Reference in New Issue
Block a user