forked from quic-go/quic-go
use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)
* use Transport.VerifySourceAddress to control the Retry Mechanism This can be used to rate-limit handshakes originating from unverified source addresses. Rate-limiting for handshakes can be implemented using the GetConfigForClient callback on the Config. * pass the remote address to Transport.VerifySourceAddress
This commit is contained in:
148
server.go
148
server.go
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
@@ -108,10 +107,7 @@ type baseServer struct {
|
||||
connectionRefusedQueue chan rejectedPacket
|
||||
retryQueue chan rejectedPacket
|
||||
|
||||
maxNumHandshakesUnvalidated int
|
||||
maxNumHandshakesTotal int
|
||||
numHandshakesUnvalidated atomic.Int64
|
||||
numHandshakesValidated atomic.Int64
|
||||
verifySourceAddress func(net.Addr) bool
|
||||
|
||||
connQueue chan quicConn
|
||||
|
||||
@@ -241,34 +237,33 @@ func newServer(
|
||||
onClose func(),
|
||||
tokenGeneratorKey TokenGeneratorKey,
|
||||
maxTokenAge time.Duration,
|
||||
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
|
||||
verifySourceAddress func(net.Addr) bool,
|
||||
disableVersionNegotiation bool,
|
||||
acceptEarly bool,
|
||||
) *baseServer {
|
||||
s := &baseServer{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||
maxTokenAge: maxTokenAge,
|
||||
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
|
||||
maxNumHandshakesTotal: maxNumHandshakesTotal,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||
invalidTokenQueue: make(chan rejectedPacket, 4),
|
||||
connectionRefusedQueue: make(chan rejectedPacket, 4),
|
||||
retryQueue: make(chan rejectedPacket, 8),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
disableVersionNegotiation: disableVersionNegotiation,
|
||||
onClose: onClose,
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||
maxTokenAge: maxTokenAge,
|
||||
verifySourceAddress: verifySourceAddress,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||
invalidTokenQueue: make(chan rejectedPacket, 4),
|
||||
connectionRefusedQueue: make(chan rejectedPacket, 4),
|
||||
retryQueue: make(chan rejectedPacket, 8),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
disableVersionNegotiation: disableVersionNegotiation,
|
||||
onClose: onClose,
|
||||
}
|
||||
if acceptEarly {
|
||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||
@@ -565,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
}
|
||||
|
||||
var (
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
token *handshake.Token
|
||||
retrySrcConnID *protocol.ConnectionID
|
||||
clientAddrVerified bool
|
||||
)
|
||||
origDestConnID := hdr.DestConnectionID
|
||||
if len(hdr.Token) > 0 {
|
||||
@@ -579,46 +575,30 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
token = tok
|
||||
}
|
||||
}
|
||||
|
||||
clientAddrValidated := s.validateToken(token, p.remoteAddr)
|
||||
if token != nil && !clientAddrValidated {
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
// This also means we might send a Retry later.
|
||||
if !token.IsRetryToken {
|
||||
token = nil
|
||||
} else {
|
||||
// For Retry tokens, we send an INVALID_ERROR if
|
||||
// * the token is too old, or
|
||||
// * the token is invalid, in case of a retry token.
|
||||
select {
|
||||
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
||||
p.buffer.Release()
|
||||
if token != nil {
|
||||
clientAddrVerified = s.validateToken(token, p.remoteAddr)
|
||||
if !clientAddrVerified {
|
||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||
// We just ignore them, and act as if there was no token on this packet at all.
|
||||
// This also means we might send a Retry later.
|
||||
if !token.IsRetryToken {
|
||||
token = nil
|
||||
} else {
|
||||
// For Retry tokens, we send an INVALID_ERROR if
|
||||
// * the token is too old, or
|
||||
// * the token is invalid, in case of a retry token.
|
||||
select {
|
||||
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
|
||||
// They might decrease if another connection completes the handshake.
|
||||
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
|
||||
numHandshakesValidated := s.numHandshakesValidated.Load()
|
||||
|
||||
// Check the total handshake limit first. It's better to reject than to initiate a retry.
|
||||
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
|
||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if token == nil && numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated) {
|
||||
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
|
||||
// Retry invalidates all 0-RTT packets sent.
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
@@ -630,18 +610,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
return nil
|
||||
}
|
||||
|
||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
|
||||
RemoteAddr: p.remoteAddr,
|
||||
AddrVerified: clientAddrValidated,
|
||||
AddrVerified: clientAddrVerified,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
@@ -656,6 +629,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
var tracer *logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
@@ -665,6 +641,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
}
|
||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||
}
|
||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||
s.connHandler,
|
||||
@@ -678,7 +659,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
clientAddrValidated,
|
||||
clientAddrVerified,
|
||||
tracer,
|
||||
tracingID,
|
||||
s.logger,
|
||||
@@ -702,22 +683,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
if clientAddrValidated {
|
||||
s.numHandshakesValidated.Add(1)
|
||||
} else {
|
||||
s.numHandshakesUnvalidated.Add(1)
|
||||
}
|
||||
go conn.run()
|
||||
go func() {
|
||||
completed := s.handleNewConn(conn)
|
||||
if clientAddrValidated {
|
||||
if s.numHandshakesValidated.Add(-1) < 0 {
|
||||
panic("server BUG: number of validated handshakes negative")
|
||||
}
|
||||
} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
|
||||
panic("server BUG: number of unvalidated handshakes negative")
|
||||
}
|
||||
if !completed {
|
||||
if completed := s.handleNewConn(conn); !completed {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user