use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)

* use Transport.VerifySourceAddress to control the Retry Mechanism

This can be used to rate-limit handshakes originating from unverified
source addresses. Rate-limiting for handshakes can be implemented using
the GetConfigForClient callback on the Config.

* pass the remote address to Transport.VerifySourceAddress
This commit is contained in:
Marten Seemann
2024-03-15 10:05:52 +09:30
committed by GitHub
parent 497d3f58a5
commit 9971fedd42
12 changed files with 120 additions and 382 deletions

148
server.go
View File

@@ -7,7 +7,6 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/handshake"
@@ -108,10 +107,7 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
maxNumHandshakesUnvalidated int
maxNumHandshakesTotal int
numHandshakesUnvalidated atomic.Int64
numHandshakesValidated atomic.Int64
verifySourceAddress func(net.Addr) bool
connQueue chan quicConn
@@ -241,34 +237,33 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
s := &baseServer{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
maxNumHandshakesTotal: maxNumHandshakesTotal,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@@ -565,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var (
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
clientAddrVerified bool
)
origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
@@ -579,46 +575,30 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
token = tok
}
}
clientAddrValidated := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrValidated {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
if !token.IsRetryToken {
token = nil
} else {
// For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or
// * the token is invalid, in case of a retry token.
select {
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
p.buffer.Release()
if token != nil {
clientAddrVerified = s.validateToken(token, p.remoteAddr)
if !clientAddrVerified {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
if !token.IsRetryToken {
token = nil
} else {
// For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or
// * the token is invalid, in case of a retry token.
select {
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
p.buffer.Release()
}
return nil
}
return nil
}
}
// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
// They might decrease if another connection completes the handshake.
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
numHandshakesValidated := s.numHandshakesValidated.Load()
// Check the total handshake limit first. It's better to reject than to initiate a retry.
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
}
if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
@@ -630,18 +610,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn
tracingID := nextConnTracingID()
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrValidated,
AddrVerified: clientAddrVerified,
})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
@@ -656,6 +629,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
config = populateConfig(conf)
}
var conn quicConn
tracingID := nextConnTracingID()
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@@ -665,6 +641,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
@@ -678,7 +659,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
config,
s.tlsConf,
s.tokenGenerator,
clientAddrValidated,
clientAddrVerified,
tracer,
tracingID,
s.logger,
@@ -702,22 +683,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
if clientAddrValidated {
s.numHandshakesValidated.Add(1)
} else {
s.numHandshakesUnvalidated.Add(1)
}
go conn.run()
go func() {
completed := s.handleNewConn(conn)
if clientAddrValidated {
if s.numHandshakesValidated.Add(-1) < 0 {
panic("server BUG: number of validated handshakes negative")
}
} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
panic("server BUG: number of unvalidated handshakes negative")
}
if !completed {
if completed := s.handleNewConn(conn); !completed {
return
}