detect stateless resets for zero-length connection IDs (#5027)

This commit is contained in:
Marten Seemann
2025-04-11 21:23:53 +08:00
committed by GitHub
parent e76621f75a
commit ef2b87f5d5
5 changed files with 95 additions and 28 deletions

View File

@@ -294,6 +294,22 @@ func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
delete(h.pathProbing, pathID)
}
func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool {
if h.activeStatelessResetToken != nil {
if *h.activeStatelessResetToken == token {
return true
}
}
if h.pathProbing != nil {
for _, entry := range h.pathProbing {
if entry.StatelessResetToken == token {
return true
}
}
}
return false
}
// Using the connIDManager after it has been closed can have disastrous effects:
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
// leading to a memory leak of the connection struct.

View File

@@ -154,22 +154,34 @@ func TestConnIDManagerHandshakeCompletion(t *testing.T) {
}
func TestConnIDManagerConnIDRotation(t *testing.T) {
toToken := func(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken
copy(token[:], connID.Bytes())
copy(token[connID.Len():], connID.Bytes())
return token
}
var frameQueue []wire.Frame
var addedTokens, removedTokens []protocol.StatelessResetToken
m := newConnIDManager(
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
func(protocol.StatelessResetToken) {},
func(protocol.StatelessResetToken) {},
func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) },
func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) },
func(f wire.Frame) { frameQueue = append(frameQueue, f) },
)
// the first connection ID is used as soon as the handshake is complete
m.SetHandshakeComplete()
firstConnID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: 1,
ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
ConnectionID: firstConnID,
StatelessResetToken: toToken(protocol.ParseConnectionID([]byte{4, 3, 2, 1})),
}))
require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), m.Get())
require.Equal(t, firstConnID, m.Get())
frameQueue = nil
require.True(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(firstConnID)})
addedTokens = addedTokens[:0]
// Note that we're missing the connection ID with sequence number 2.
// It will be received later.
@@ -182,8 +194,9 @@ func TestConnIDManagerConnIDRotation(t *testing.T) {
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: uint64(3 + i),
ConnectionID: connID,
StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
StatelessResetToken: toToken(connID),
}))
require.False(t, m.IsActiveStatelessResetToken(toToken(connID)))
}
var counter int
@@ -191,11 +204,19 @@ func TestConnIDManagerConnIDRotation(t *testing.T) {
require.Empty(t, frameQueue)
m.SentPacket()
counter++
if m.Get() != protocol.ParseConnectionID([]byte{4, 3, 2, 1}) {
if connID := m.Get(); connID != firstConnID {
require.Equal(t, queuedConnIDs[0], m.Get())
require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 1}}, frameQueue)
require.Equal(t, removedTokens, []protocol.StatelessResetToken{toToken(firstConnID)})
require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(connID)})
addedTokens = addedTokens[:0]
removedTokens = removedTokens[:0]
require.True(t, m.IsActiveStatelessResetToken(toToken(connID)))
require.False(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
break
}
require.True(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
require.Empty(t, addedTokens)
}
require.GreaterOrEqual(t, counter, protocol.PacketsPerConnectionID/2)
require.LessOrEqual(t, counter, protocol.PacketsPerConnectionID*3/2)
@@ -223,7 +244,7 @@ func TestConnIDManagerPathMigration(t *testing.T) {
_, ok := m.GetConnIDForPath(1)
require.False(t, ok)
// add a connection ID
// add two connection IDs
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: 1,
ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
@@ -241,11 +262,13 @@ func TestConnIDManagerPathMigration(t *testing.T) {
require.Empty(t, removedTokens)
addedTokens = addedTokens[:0]
require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}))
connID, ok = m.GetConnIDForPath(2)
require.True(t, ok)
require.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2}), connID)
require.Equal(t, []protocol.StatelessResetToken{{5, 4, 3, 2, 5, 4, 3, 2}}, addedTokens)
require.Empty(t, removedTokens)
require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}))
addedTokens = addedTokens[:0]
// asking for the connection for path 1 again returns the same connection ID
@@ -294,12 +317,14 @@ func TestConnIDManagerPathMigration(t *testing.T) {
require.Equal(t, protocol.ParseConnectionID([]byte{7, 6, 5, 4}), connID)
require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, addedTokens)
require.Empty(t, removedTokens)
require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13}))
// a RETIRE_CONNECTION_ID frame for path 1 is queued when retiring the connection ID
m.RetireConnIDForPath(1)
require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 4}}, frameQueue)
require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, removedTokens)
removedTokens = removedTokens[:0]
require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13}))
m.Close()
require.Equal(t, []protocol.StatelessResetToken{

View File

@@ -977,7 +977,7 @@ func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ er
if counter > 0 {
p.buffer.Split()
}
processed, err := s.handleShortHeaderPacket(p)
processed, err := s.handleShortHeaderPacket(p, counter > 0)
if err != nil {
return false, err
}
@@ -992,7 +992,7 @@ func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ er
return wasProcessed, nil
}
func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed bool, _ error) {
func (s *connection) handleShortHeaderPacket(p receivedPacket, isCoalesced bool) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
@@ -1009,6 +1009,17 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
}
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
// Stateless reset packets (see RFC 9000, section 10.3):
// * fill the entire UDP datagram (i.e. they cannot be part of a coalesced packet)
// * are short header packets (first bit is 0)
// * have the QUIC bit set (second bit is 1)
// * are at least 21 bytes long
if !isCoalesced && len(p.data) >= protocol.MinReceivedStatelessResetSize && p.data[0]&0b11000000 == 0b01000000 {
token := protocol.StatelessResetToken(p.data[len(p.data)-16:])
if s.connIDManager.IsActiveStatelessResetToken(token) {
return false, &StatelessResetError{}
}
}
wasQueued, err = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false, err
}

View File

@@ -14,7 +14,7 @@ import (
)
func TestStatelessResets(t *testing.T) {
t.Run("0 byte connection IDs", func(t *testing.T) {
t.Run("zero-length connection IDs", func(t *testing.T) {
testStatelessReset(t, 0)
})
t.Run("10 byte connection IDs", func(t *testing.T) {
@@ -28,9 +28,8 @@ func testStatelessReset(t *testing.T, connIDLen int) {
c := newUDPConnLocalhost(t)
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
ConnectionIDLength: connIDLen,
Conn: c,
StatelessResetKey: &statelessResetKey,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
@@ -65,18 +64,31 @@ func testStatelessReset(t *testing.T, connIDLen int) {
require.NoError(t, proxy.Start())
defer proxy.Close()
cl := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: connIDLen,
var conn quic.Connection
if connIDLen > 0 {
cl := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnectionIDLength: connIDLen,
}
defer cl.Close()
var err error
conn, err = cl.Dial(
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
} else {
conn, err = quic.Dial(
context.Background(),
newUDPConnLocalhost(t),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
}
defer cl.Close()
conn, err := cl.Dial(
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
require.NoError(t, err)
str, err := conn.AcceptStream(context.Background())
require.NoError(t, err)
data := make([]byte, 6)
@@ -94,9 +106,8 @@ func testStatelessReset(t *testing.T, connIDLen int) {
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections).
tr2 := &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
StatelessResetKey: &statelessResetKey,
Conn: c,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))

View File

@@ -123,6 +123,10 @@ const MinUnknownVersionPacketSize = MinInitialPacketSize
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
// MinReceivedStatelessResetSize is the minimum size of a received stateless reset,
// as specified in section 10.3 of RFC 9000.
const MinReceivedStatelessResetSize = 5 + 16
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8