also use the multiplexer for the server

This commit is contained in:
Marten Seemann
2018-07-20 08:26:36 -04:00
parent c8d20e86d7
commit ad5a3e2fa0
15 changed files with 631 additions and 512 deletions

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"net"
"strings"
"sync"
"time"
@@ -24,6 +23,7 @@ type packetHandlerMap struct {
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandler
server unknownPacketHandler
closed bool
deleteClosedSessionsAfter time.Duration
@@ -33,8 +33,7 @@ type packetHandlerMap struct {
var _ packetHandlerManager = &packetHandlerMap{}
// TODO(#561): remove the listen flag
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager {
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
@@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
logger: logger,
}
if listen {
go m.listen()
}
go m.listen()
return m
}
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.RLock()
sess, ok := h.handlers[string(id)]
h.mutex.RUnlock()
return sess, ok
}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = handler
@@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.removeByConnectionIDAsString(string(id))
}
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
h.mutex.Lock()
h.handlers[string(id)] = nil
h.handlers[id] = nil
h.mutex.Unlock()
time.AfterFunc(h.deleteClosedSessionsAfter, func() {
h.mutex.Lock()
delete(h.handlers, string(id))
delete(h.handlers, id)
h.mutex.Unlock()
})
}
func (h *packetHandlerMap) Close() error {
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
h.mutex.Lock()
h.server = s
h.mutex.Unlock()
}
func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
h.server = nil
var wg sync.WaitGroup
for id, handler := range h.handlers {
if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(id string, handler packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = handler.Close()
h.removeByConnectionIDAsString(id)
wg.Done()
}(id, handler)
}
}
h.mutex.Unlock()
wg.Wait()
}
func (h *packetHandlerMap) close(e error) error {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
@@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error {
if handler != nil {
wg.Add(1)
go func(handler packetHandler) {
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
_ = handler.Close()
handler.destroy(e)
wg.Done()
}(handler)
}
}
if h.server != nil {
h.server.closeWithError(e)
}
h.mutex.Unlock()
wg.Wait()
return nil
@@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() {
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
h.Close()
}
h.close(err)
return
}
data = data[:n]
@@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
handler, ok := h.Get(iHdr.DestConnectionID)
if !ok {
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if handler == nil {
h.mutex.RLock()
handler, ok := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
h.mutex.RUnlock()
var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if ok && handler == nil {
// Late packet for closed session
return nil
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
if !ok {
if server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
} else {
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
}
hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
@@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
// TODO(#1312): implement parsing of compound packets
}
handler.handlePacket(&receivedPacket{
handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,