diff --git a/packet_handler_map.go b/packet_handler_map.go index f3e1d019b..e89baf49a 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -183,6 +183,10 @@ func (h *packetHandlerMap) handlePacket( h.mutex.RLock() defer h.mutex.RUnlock() + if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset { + return + } + handlerEntry, handlerFound := h.handlers[string(connID)] p := &receivedPacket{ @@ -195,17 +199,7 @@ func (h *packetHandlerMap) handlePacket( handlerEntry.handler.handlePacket(p) return } - // No session found. - // This might be a stateless reset. - if data[0]&0x80 == 0 { // stateless resets are always short header packets - if len(p.data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], p.data[len(p.data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - return - } - } + if data[0]&0x80 == 0 { // TODO(#943): send a stateless reset h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) return @@ -216,3 +210,21 @@ func (h *packetHandlerMap) handlePacket( } h.server.handlePacket(p) } + +func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if data[0]&0x80 != 0 { + return false + } + if len(data) < protocol.MinStatelessResetSize { + return false + } + + var token [16]byte + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + return true + } + return false +} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 2f1e03a88..00a2d40e7 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -195,6 +195,21 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(protocol.ConnectionID{}, packetHandler, token) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}