diff --git a/packet_handler_map.go b/packet_handler_map.go index e89baf49..0ae65b51 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -11,11 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type packetHandlerEntry struct { - handler packetHandler - resetToken *[16]byte -} - // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store sessions @@ -26,7 +21,7 @@ type packetHandlerMap struct { conn net.PacketConn connIDLen int - handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry + handlers map[string] /* string(ConnectionID)*/ packetHandler resetTokens map[[16]byte] /* stateless reset token */ packetHandler server unknownPacketHandler @@ -45,7 +40,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandlerEntry), + handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, logger: logger, @@ -56,14 +51,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = packetHandlerEntry{handler: handler} - h.mutex.Unlock() -} - -func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) { - h.mutex.Lock() - h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token} - h.resetTokens[token] = handler + h.handlers[string(id)] = handler h.mutex.Unlock() } @@ -73,12 +61,7 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { h.mutex.Lock() - if handlerEntry, ok := h.handlers[id]; ok { - if token := handlerEntry.resetToken; token != nil { - delete(h.resetTokens, *token) - } - delete(h.handlers, id) - } + delete(h.handlers, id) h.mutex.Unlock() } @@ -92,6 +75,18 @@ func (h *packetHandlerMap) retireByConnectionIDAsString(id string) { }) } +func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) { + h.mutex.Lock() + h.resetTokens[token] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token [16]byte) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { h.mutex.Lock() h.server = s @@ -102,8 +97,7 @@ func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() h.server = nil var wg sync.WaitGroup - for id, handlerEntry := range h.handlers { - handler := handlerEntry.handler + for id, handler := range h.handlers { if handler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(id string, handler packetHandler) { @@ -136,12 +130,12 @@ func (h *packetHandlerMap) close(e error) error { h.closed = true var wg sync.WaitGroup - for _, handlerEntry := range h.handlers { + for _, handler := range h.handlers { wg.Add(1) - go func(handlerEntry packetHandlerEntry) { - handlerEntry.handler.destroy(e) + go func(handler packetHandler) { + handler.destroy(e) wg.Done() - }(handlerEntry) + }(handler) } if h.server != nil { @@ -187,7 +181,7 @@ func (h *packetHandlerMap) handlePacket( return } - handlerEntry, handlerFound := h.handlers[string(connID)] + handler, handlerFound := h.handlers[string(connID)] p := &receivedPacket{ remoteAddr: addr, @@ -196,7 +190,7 @@ func (h *packetHandlerMap) handlePacket( data: data, } if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) + handler.handlePacket(p) return } if data[0]&0x80 == 0 { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 00a2d40e..1b078c1b 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -163,28 +163,10 @@ var _ = Describe("Packet Handler Map", func() { }) Context("stateless reset handling", func() { - It("handles packets for connections added with a reset token", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) - // first send a normal packet - handledPacket := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - close(handledPacket) - }) - conn.dataToRead <- getPacket(connID) - Eventually(handledPacket).Should(BeClosed()) - }) - It("handles stateless resets", func() { packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) + handler.AddResetToken(token, packetHandler) packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) destroyed := make(chan struct{}) @@ -199,7 +181,7 @@ var _ = Describe("Packet Handler Map", func() { handler.connIDLen = 0 packetHandler := NewMockPacketHandler(mockCtrl) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(protocol.ConnectionID{}, packetHandler, token) + handler.AddResetToken(token, packetHandler) packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) destroyed := make(chan struct{}) @@ -210,13 +192,12 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) - It("deletes reset tokens when the session is retired", func() { + It("deletes reset tokens", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RemoveResetToken(token) handler.handlePacket(nil, nil, getPacket(connID)) // don't EXPECT any calls to handlePacket of the MockPacketHandler packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)