diff --git a/interface.go b/interface.go index 4827c7d9..6402d076 100644 --- a/interface.go +++ b/interface.go @@ -209,6 +209,7 @@ type Config struct { // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. StatelessResetKey []byte // KeepAlive defines whether this peer will periodically send a packet to keep the connection alive. KeepAlive bool diff --git a/packet_handler_map.go b/packet_handler_map.go index 43c3f24b..0af6a407 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -34,7 +34,8 @@ type packetHandlerMap struct { deleteRetiredSessionsAfter time.Duration - statelessResetHasher hash.Hash + statelessResetEnabled bool + statelessResetHasher hash.Hash logger utils.Logger } @@ -54,6 +55,7 @@ func newPacketHandlerMap( handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, + statelessResetEnabled: len(statelessResetKey) > 0, statelessResetHasher: hmac.New(sha256.New, statelessResetKey), logger: logger, } @@ -236,8 +238,15 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { } func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { - h.statelessResetHasher.Write(connID.Bytes()) var token [16]byte + if !h.statelessResetEnabled { + // Return a random stateless reset token. + // This token will be sent in the server's transport parameters. + // By using a random token, an off-path attacker won't be able to disrupt the connection. + rand.Read(token[:]) + return token + } + h.statelessResetHasher.Write(connID.Bytes()) copy(token[:], h.statelessResetHasher.Sum(nil)) h.statelessResetHasher.Reset() return token @@ -245,6 +254,9 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { defer p.buffer.Release() + if !h.statelessResetEnabled { + return + } // Don't send a stateless reset in response to very small packets. // This includes packets that could be stateless resets. if len(p.data) <= protocol.MinStatelessResetSize { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 051f518e..e52c1305 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "crypto/rand" "errors" "net" "time" @@ -18,6 +19,9 @@ var _ = Describe("Packet Handler Map", func() { var ( handler *packetHandlerMap conn *mockPacketConn + + connIDLen int + statelessResetKey []byte ) getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { @@ -40,8 +44,13 @@ var _ = Describe("Packet Handler Map", func() { } BeforeEach(func() { + statelessResetKey = nil + connIDLen = 0 + }) + + JustBeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, 5, nil, utils.DefaultLogger).(*packetHandlerMap) + handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap) }) AfterEach(func() { @@ -80,6 +89,10 @@ var _ = Describe("Packet Handler Map", func() { }) Context("handling packets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + It("handles packets for different packet handlers on the same packet conn", func() { connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} @@ -163,81 +176,6 @@ var _ = Describe("Packet Handler Map", func() { }) }) - Context("stateless reset handling", func() { - It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - token1 := handler.GetStatelessResetToken(connID1) - Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) - Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) - }) - - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) - }) - - It("deletes reset tokens", func() { - handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RemoveResetToken(token) - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - handler.handlePacket(nil, nil, p) - // destroy() would be called from a separate go routine - // make sure we give it enough time to be called to cause an error here - time.Sleep(scaleDuration(25 * time.Millisecond)) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(addr, getPacketBuffer(), p) - var reset mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&reset)) - Expect(reset.to).To(Equal(addr)) - Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet - Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) - }) - - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(addr, getPacketBuffer(), p) - Consistently(conn.dataWritten).ShouldNot(Receive()) - }) - }) - Context("running a server", func() { It("adds a server", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} @@ -274,4 +212,102 @@ var _ = Describe("Packet Handler Map", func() { handler.handlePacket(nil, nil, p) }) }) + + Context("stateless resets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + + Context("handling", func() { + It("handles stateless resets", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + + It("deletes reset tokens", func() { + handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RemoveResetToken(token) + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + handler.handlePacket(nil, nil, p) + // destroy() would be called from a separate go routine + // make sure we give it enough time to be called to cause an error here + time.Sleep(scaleDuration(25 * time.Millisecond)) + }) + }) + + Context("generating", func() { + BeforeEach(func() { + key := make([]byte, 32) + rand.Read(key) + statelessResetKey = key + }) + + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + token1 := handler.GetStatelessResetToken(connID1) + Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) + Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + var reset mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reset)) + Expect(reset.to).To(Equal(addr)) + Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet + Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) + }) + + It("doesn't send stateless resets for small packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) + handler.handlePacket(addr, getPacketBuffer(), p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + }) + + Context("if no key is configured", func() { + It("doesn't send stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + }) + }) })