diff --git a/packet_handler_map.go b/packet_handler_map.go index ff7bd7b74..919c2304b 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -57,11 +57,6 @@ type rawConn interface { io.Closer } -type packetHandlerMapEntry struct { - packetHandler packetHandler - is0RTTQueue bool -} - // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store connections @@ -72,7 +67,7 @@ type packetHandlerMap struct { conn rawConn connIDLen int - handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry + handlers map[string] /* string(ConnectionID)*/ packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler server unknownPacketHandler numZeroRTTEntries int @@ -151,7 +146,7 @@ func newPacketHandlerMap( conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandlerMapEntry), + handlers: make(map[string]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, @@ -202,7 +197,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) return false } - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.handlers[string(id)] = handler h.logger.Debugf("Adding connection ID %s.", id) return true } @@ -212,24 +207,24 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co defer h.mutex.Unlock() var q *zeroRTTQueue - if entry, ok := h.handlers[string(clientDestConnID)]; ok { - if !entry.is0RTTQueue { + if handler, ok := h.handlers[string(clientDestConnID)]; ok { + q, ok = handler.(*zeroRTTQueue) + if !ok { h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } - q = entry.packetHandler.(*zeroRTTQueue) q.retireTimer.Stop() h.numZeroRTTEntries-- if h.numZeroRTTEntries < 0 { panic("number of 0-RTT queues < 0") } } - sess := fn() + conn := fn() if q != nil { - q.EnqueueAll(sess) + q.EnqueueAll(conn) } - h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.handlers[string(clientDestConnID)] = conn + h.handlers[string(newConnID)] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } @@ -253,7 +248,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.handlers[string(id)] = handler h.mutex.Unlock() h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) @@ -292,14 +287,14 @@ func (h *packetHandlerMap) CloseServer() { } h.server = nil var wg sync.WaitGroup - for _, entry := range h.handlers { - if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { + for _, handler := range h.handlers { + if handler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(handler packetHandler) { // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped handler.shutdown() wg.Done() - }(entry.packetHandler) + }(handler) } } h.mutex.Unlock() @@ -324,12 +319,12 @@ func (h *packetHandlerMap) close(e error) error { } var wg sync.WaitGroup - for _, entry := range h.handlers { + for _, handler := range h.handlers { wg.Add(1) go func(handler packetHandler) { handler.destroy(e) wg.Done() - }(entry.packetHandler) + }(handler) } if h.server != nil { @@ -379,14 +374,14 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { return } - if entry, ok := h.handlers[string(connID)]; ok { - if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue + if handler, ok := h.handlers[string(connID)]; ok { + if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue if wire.Is0RTTPacket(p.data) { - entry.packetHandler.handlePacket(p) + ha.handlePacket(p) return } } else { // existing connection - entry.packetHandler.handlePacket(p) + handler.handlePacket(p) return } } @@ -404,24 +399,23 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { } h.numZeroRTTEntries++ queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[string(connID)] = packetHandlerMapEntry{ - packetHandler: queue, - is0RTTQueue: true, - } + h.handlers[string(connID)] = queue queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { h.mutex.Lock() defer h.mutex.Unlock() // The entry might have been replaced by an actual connection. // Only delete it if it's still a 0-RTT queue. - if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { - delete(h.handlers, string(connID)) - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - entry.packetHandler.(*zeroRTTQueue).Clear() - if h.logger.Debug() { - h.logger.Debugf("Removing 0-RTT queue for %s.", connID) + if handler, ok := h.handlers[string(connID)]; ok { + if q, ok := handler.(*zeroRTTQueue); ok { + delete(h.handlers, string(connID)) + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + q.Clear() + if h.logger.Debug() { + h.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } } } })