forked from quic-go/quic-go
detect stateless resets for zero-length connection IDs (#5027)
This commit is contained in:
@@ -294,6 +294,22 @@ func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
|
||||
delete(h.pathProbing, pathID)
|
||||
}
|
||||
|
||||
func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool {
|
||||
if h.activeStatelessResetToken != nil {
|
||||
if *h.activeStatelessResetToken == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if h.pathProbing != nil {
|
||||
for _, entry := range h.pathProbing {
|
||||
if entry.StatelessResetToken == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Using the connIDManager after it has been closed can have disastrous effects:
|
||||
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
|
||||
// leading to a memory leak of the connection struct.
|
||||
|
||||
@@ -154,22 +154,34 @@ func TestConnIDManagerHandshakeCompletion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConnIDManagerConnIDRotation(t *testing.T) {
|
||||
toToken := func(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
var token protocol.StatelessResetToken
|
||||
copy(token[:], connID.Bytes())
|
||||
copy(token[connID.Len():], connID.Bytes())
|
||||
return token
|
||||
}
|
||||
|
||||
var frameQueue []wire.Frame
|
||||
var addedTokens, removedTokens []protocol.StatelessResetToken
|
||||
m := newConnIDManager(
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
func(protocol.StatelessResetToken) {},
|
||||
func(protocol.StatelessResetToken) {},
|
||||
func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) },
|
||||
func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) },
|
||||
func(f wire.Frame) { frameQueue = append(frameQueue, f) },
|
||||
)
|
||||
// the first connection ID is used as soon as the handshake is complete
|
||||
m.SetHandshakeComplete()
|
||||
firstConnID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: 1,
|
||||
ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
|
||||
StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
ConnectionID: firstConnID,
|
||||
StatelessResetToken: toToken(protocol.ParseConnectionID([]byte{4, 3, 2, 1})),
|
||||
}))
|
||||
require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), m.Get())
|
||||
require.Equal(t, firstConnID, m.Get())
|
||||
frameQueue = nil
|
||||
require.True(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
|
||||
require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(firstConnID)})
|
||||
addedTokens = addedTokens[:0]
|
||||
|
||||
// Note that we're missing the connection ID with sequence number 2.
|
||||
// It will be received later.
|
||||
@@ -182,8 +194,9 @@ func TestConnIDManagerConnIDRotation(t *testing.T) {
|
||||
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: uint64(3 + i),
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
StatelessResetToken: toToken(connID),
|
||||
}))
|
||||
require.False(t, m.IsActiveStatelessResetToken(toToken(connID)))
|
||||
}
|
||||
|
||||
var counter int
|
||||
@@ -191,11 +204,19 @@ func TestConnIDManagerConnIDRotation(t *testing.T) {
|
||||
require.Empty(t, frameQueue)
|
||||
m.SentPacket()
|
||||
counter++
|
||||
if m.Get() != protocol.ParseConnectionID([]byte{4, 3, 2, 1}) {
|
||||
if connID := m.Get(); connID != firstConnID {
|
||||
require.Equal(t, queuedConnIDs[0], m.Get())
|
||||
require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 1}}, frameQueue)
|
||||
require.Equal(t, removedTokens, []protocol.StatelessResetToken{toToken(firstConnID)})
|
||||
require.Equal(t, addedTokens, []protocol.StatelessResetToken{toToken(connID)})
|
||||
addedTokens = addedTokens[:0]
|
||||
removedTokens = removedTokens[:0]
|
||||
require.True(t, m.IsActiveStatelessResetToken(toToken(connID)))
|
||||
require.False(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
|
||||
break
|
||||
}
|
||||
require.True(t, m.IsActiveStatelessResetToken(toToken(firstConnID)))
|
||||
require.Empty(t, addedTokens)
|
||||
}
|
||||
require.GreaterOrEqual(t, counter, protocol.PacketsPerConnectionID/2)
|
||||
require.LessOrEqual(t, counter, protocol.PacketsPerConnectionID*3/2)
|
||||
@@ -223,7 +244,7 @@ func TestConnIDManagerPathMigration(t *testing.T) {
|
||||
_, ok := m.GetConnIDForPath(1)
|
||||
require.False(t, ok)
|
||||
|
||||
// add a connection ID
|
||||
// add two connection IDs
|
||||
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: 1,
|
||||
ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
|
||||
@@ -241,11 +262,13 @@ func TestConnIDManagerPathMigration(t *testing.T) {
|
||||
require.Empty(t, removedTokens)
|
||||
|
||||
addedTokens = addedTokens[:0]
|
||||
require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}))
|
||||
connID, ok = m.GetConnIDForPath(2)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2}), connID)
|
||||
require.Equal(t, []protocol.StatelessResetToken{{5, 4, 3, 2, 5, 4, 3, 2}}, addedTokens)
|
||||
require.Empty(t, removedTokens)
|
||||
require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}))
|
||||
|
||||
addedTokens = addedTokens[:0]
|
||||
// asking for the connection for path 1 again returns the same connection ID
|
||||
@@ -294,12 +317,14 @@ func TestConnIDManagerPathMigration(t *testing.T) {
|
||||
require.Equal(t, protocol.ParseConnectionID([]byte{7, 6, 5, 4}), connID)
|
||||
require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, addedTokens)
|
||||
require.Empty(t, removedTokens)
|
||||
require.True(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13}))
|
||||
|
||||
// a RETIRE_CONNECTION_ID frame for path 1 is queued when retiring the connection ID
|
||||
m.RetireConnIDForPath(1)
|
||||
require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 4}}, frameQueue)
|
||||
require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, removedTokens)
|
||||
removedTokens = removedTokens[:0]
|
||||
require.False(t, m.IsActiveStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13}))
|
||||
|
||||
m.Close()
|
||||
require.Equal(t, []protocol.StatelessResetToken{
|
||||
|
||||
@@ -977,7 +977,7 @@ func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ er
|
||||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
processed, err := s.handleShortHeaderPacket(p)
|
||||
processed, err := s.handleShortHeaderPacket(p, counter > 0)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -992,7 +992,7 @@ func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ er
|
||||
return wasProcessed, nil
|
||||
}
|
||||
|
||||
func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed bool, _ error) {
|
||||
func (s *connection) handleShortHeaderPacket(p receivedPacket, isCoalesced bool) (wasProcessed bool, _ error) {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
@@ -1009,6 +1009,17 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed boo
|
||||
}
|
||||
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
|
||||
if err != nil {
|
||||
// Stateless reset packets (see RFC 9000, section 10.3):
|
||||
// * fill the entire UDP datagram (i.e. they cannot be part of a coalesced packet)
|
||||
// * are short header packets (first bit is 0)
|
||||
// * have the QUIC bit set (second bit is 1)
|
||||
// * are at least 21 bytes long
|
||||
if !isCoalesced && len(p.data) >= protocol.MinReceivedStatelessResetSize && p.data[0]&0b11000000 == 0b01000000 {
|
||||
token := protocol.StatelessResetToken(p.data[len(p.data)-16:])
|
||||
if s.connIDManager.IsActiveStatelessResetToken(token) {
|
||||
return false, &StatelessResetError{}
|
||||
}
|
||||
}
|
||||
wasQueued, err = s.handleUnpackError(err, p, logging.PacketType1RTT)
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStatelessResets(t *testing.T) {
|
||||
t.Run("0 byte connection IDs", func(t *testing.T) {
|
||||
t.Run("zero-length connection IDs", func(t *testing.T) {
|
||||
testStatelessReset(t, 0)
|
||||
})
|
||||
t.Run("10 byte connection IDs", func(t *testing.T) {
|
||||
@@ -28,9 +28,8 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
||||
|
||||
c := newUDPConnLocalhost(t)
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
ConnectionIDLength: connIDLen,
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
@@ -65,18 +64,31 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
||||
require.NoError(t, proxy.Start())
|
||||
defer proxy.Close()
|
||||
|
||||
cl := &quic.Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnectionIDLength: connIDLen,
|
||||
var conn quic.Connection
|
||||
if connIDLen > 0 {
|
||||
cl := &quic.Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer cl.Close()
|
||||
var err error
|
||||
conn, err = cl.Dial(
|
||||
context.Background(),
|
||||
proxy.LocalAddr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
conn, err = quic.Dial(
|
||||
context.Background(),
|
||||
newUDPConnLocalhost(t),
|
||||
proxy.LocalAddr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
defer cl.Close()
|
||||
conn, err := cl.Dial(
|
||||
context.Background(),
|
||||
proxy.LocalAddr(),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
require.NoError(t, err)
|
||||
data := make([]byte, 6)
|
||||
@@ -94,9 +106,8 @@ func testStatelessReset(t *testing.T, connIDLen int) {
|
||||
// We need to create a new Transport here, since the old one is still sending out
|
||||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
|
||||
@@ -123,6 +123,10 @@ const MinUnknownVersionPacketSize = MinInitialPacketSize
|
||||
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
|
||||
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
|
||||
|
||||
// MinReceivedStatelessResetSize is the minimum size of a received stateless reset,
|
||||
// as specified in section 10.3 of RFC 9000.
|
||||
const MinReceivedStatelessResetSize = 5 + 16
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
||||
|
||||
Reference in New Issue
Block a user