only send stateless resets if a stateless reset key is configured

This commit is contained in:
Marten Seemann
2019-03-07 09:25:46 +09:00
parent 5c20519743
commit b3fe0fdbf9
3 changed files with 127 additions and 78 deletions

View File

@@ -209,6 +209,7 @@ type Config struct {
// If set to a negative value, it doesn't allow any unidirectional streams.
MaxIncomingUniStreams int
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
StatelessResetKey []byte
// KeepAlive defines whether this peer will periodically send a packet to keep the connection alive.
KeepAlive bool

View File

@@ -34,6 +34,7 @@ type packetHandlerMap struct {
deleteRetiredSessionsAfter time.Duration
statelessResetEnabled bool
statelessResetHasher hash.Hash
logger utils.Logger
@@ -54,6 +55,7 @@ func newPacketHandlerMap(
handlers: make(map[string]packetHandler),
resetTokens: make(map[[16]byte]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
statelessResetEnabled: len(statelessResetKey) > 0,
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
logger: logger,
}
@@ -236,8 +238,15 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
}
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
h.statelessResetHasher.Write(connID.Bytes())
var token [16]byte
if !h.statelessResetEnabled {
// Return a random stateless reset token.
// This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection.
rand.Read(token[:])
return token
}
h.statelessResetHasher.Write(connID.Bytes())
copy(token[:], h.statelessResetHasher.Sum(nil))
h.statelessResetHasher.Reset()
return token
@@ -245,6 +254,9 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if !h.statelessResetEnabled {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"crypto/rand"
"errors"
"net"
"time"
@@ -18,6 +19,9 @@ var _ = Describe("Packet Handler Map", func() {
var (
handler *packetHandlerMap
conn *mockPacketConn
connIDLen int
statelessResetKey []byte
)
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte {
@@ -40,8 +44,13 @@ var _ = Describe("Packet Handler Map", func() {
}
BeforeEach(func() {
statelessResetKey = nil
connIDLen = 0
})
JustBeforeEach(func() {
conn = newMockPacketConn()
handler = newPacketHandlerMap(conn, 5, nil, utils.DefaultLogger).(*packetHandlerMap)
handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap)
})
AfterEach(func() {
@@ -80,6 +89,10 @@ var _ = Describe("Packet Handler Map", func() {
})
Context("handling packets", func() {
BeforeEach(func() {
connIDLen = 5
})
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
@@ -163,15 +176,49 @@ var _ = Describe("Packet Handler Map", func() {
})
})
Context("stateless reset handling", func() {
It("generates stateless reset tokens", func() {
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
token1 := handler.GetStatelessResetToken(connID1)
Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1))
Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1))
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(nil, nil, p)
})
It("closes all server sessions", func() {
clientSess := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().Close()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(nil, nil, p)
})
})
Context("stateless resets", func() {
BeforeEach(func() {
connIDLen = 5
})
Context("handling", func() {
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
@@ -218,6 +265,22 @@ var _ = Describe("Packet Handler Map", func() {
// make sure we give it enough time to be called to cause an error here
time.Sleep(scaleDuration(25 * time.Millisecond))
})
})
Context("generating", func() {
BeforeEach(func() {
key := make([]byte, 32)
rand.Read(key)
statelessResetKey = key
})
It("generates stateless reset tokens", func() {
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
token1 := handler.GetStatelessResetToken(connID1)
Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1))
Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1))
})
It("sends stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
@@ -238,40 +301,13 @@ var _ = Describe("Packet Handler Map", func() {
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(nil, nil, p)
})
It("closes all server sessions", func() {
clientSess := NewMockPacketHandler(mockCtrl)
clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverSess := NewMockPacketHandler(mockCtrl)
serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverSess.EXPECT().Close()
handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess)
handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(nil, nil, p)
Context("if no key is configured", func() {
It("doesn't send stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(addr, getPacketBuffer(), p)
Consistently(conn.dataWritten).ShouldNot(Receive())
})
})
})
})