From 39e1e9a6f81f6643cb82fe506a8c50c6e8b4053a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 19 Nov 2018 11:26:26 +0700 Subject: [PATCH] implement receiving of stateless resets --- internal/protocol/params.go | 3 ++ packet_handler_map.go | 69 ++++++++++++++++++++++++++++--------- packet_handler_map_test.go | 48 ++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 17 deletions(-) diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 58637d90..13847ffd 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -8,6 +8,9 @@ const MaxPacketSizeIPv4 = 1252 // MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. const MaxPacketSizeIPv6 = 1232 +// MinStatelessResetSize is the minimum size of a stateless reset packet +const MinStatelessResetSize = 1 + 20 + 16 + // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames const NonForwardSecurePacketSizeReduction = 50 diff --git a/packet_handler_map.go b/packet_handler_map.go index f2f225d3..32496758 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "fmt" "net" "sync" @@ -12,6 +13,11 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type packetHandlerEntry struct { + handler packetHandler + resetToken *[16]byte +} + // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store sessions @@ -22,9 +28,10 @@ type packetHandlerMap struct { conn net.PacketConn connIDLen int - handlers map[string] /* string(ConnectionID)*/ packetHandler - server unknownPacketHandler - closed bool + handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry + resetTokens map[[16]byte] /* stateless reset token */ packetHandler + server unknownPacketHandler + closed bool deleteRetiredSessionsAfter time.Duration @@ -37,7 +44,8 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, - handlers: make(map[string]packetHandler), + handlers: make(map[string]packetHandlerEntry), + resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, logger: logger, } @@ -47,13 +55,29 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = handler + h.handlers[string(id)] = packetHandlerEntry{handler: handler} + h.mutex.Unlock() +} + +func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) { + h.mutex.Lock() + h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token} + h.resetTokens[token] = handler h.mutex.Unlock() } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.removeByConnectionIDAsString(string(id)) +} + +func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { h.mutex.Lock() - delete(h.handlers, string(id)) + if handlerEntry, ok := h.handlers[id]; ok { + if token := handlerEntry.resetToken; token != nil { + delete(h.resetTokens, *token) + } + delete(h.handlers, id) + } h.mutex.Unlock() } @@ -63,9 +87,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { func (h *packetHandlerMap) retireByConnectionIDAsString(id string) { time.AfterFunc(h.deleteRetiredSessionsAfter, func() { - h.mutex.Lock() - delete(h.handlers, id) - h.mutex.Unlock() + h.removeByConnectionIDAsString(id) }) } @@ -79,7 +101,8 @@ func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() h.server = nil var wg sync.WaitGroup - for id, handler := range h.handlers { + for id, handlerEntry := range h.handlers { + handler := handlerEntry.handler if handler.GetPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(id string, handler packetHandler) { @@ -103,12 +126,12 @@ func (h *packetHandlerMap) close(e error) error { h.closed = true var wg sync.WaitGroup - for _, handler := range h.handlers { + for _, handlerEntry := range h.handlers { wg.Add(1) - go func(handler packetHandler) { - handler.destroy(e) + go func(handlerEntry packetHandlerEntry) { + handlerEntry.handler.destroy(e) wg.Done() - }(handler) + }(handlerEntry) } if h.server != nil { @@ -149,25 +172,37 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { } h.mutex.RLock() - handler, ok := h.handlers[string(iHdr.DestConnectionID)] + handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)] server := h.server - h.mutex.RUnlock() var sentBy protocol.Perspective var version protocol.VersionNumber var handlePacket func(*receivedPacket) - if ok { // existing session + if handlerFound { // existing session + handler := handlerEntry.handler sentBy = handler.GetPerspective().Opposite() version = handler.GetVersion() handlePacket = handler.handlePacket } else { // no session found + // this might be a stateless reset + if !iHdr.IsLongHeader && len(data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + h.mutex.RUnlock() + sess.destroy(errors.New("received a stateless reset")) + return nil + } + } if server == nil { // no server set + h.mutex.RUnlock() return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) } handlePacket = server.handlePacket sentBy = protocol.PerspectiveClient version = iHdr.Version } + h.mutex.RUnlock() hdr, err := iHdr.Parse(r, sentBy, version) if err != nil { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index a5d22ccb..dd2202a7 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -184,6 +184,54 @@ var _ = Describe("Packet Handler Map", func() { }) }) + Context("stateless reset handling", func() { + It("handles packets for connections added with a reset token", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, packetHandler, token) + // first send a normal packet + handledPacket := make(chan struct{}) + packetHandler.EXPECT().GetPerspective() + packetHandler.EXPECT().GetVersion() + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID)) + close(handledPacket) + }) + conn.dataToRead <- getPacket(connID) + Eventually(handledPacket).Should(BeClosed()) + }) + + It("handles stateless resets", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, packetHandler, token) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + + It("deletes reset tokens when the session is retired", func() { + handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) + handler.Retire(connID) + time.Sleep(scaleDuration(30 * time.Millisecond)) + Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) + packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + Expect(handler.handlePacket(nil, packet)).To(MatchError("received a packet with an unexpected connection ID 0xdecafbad99")) + Expect(handler.resetTokens).To(BeEmpty()) + }) + }) + Context("running a server", func() { It("adds a server", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}