pass ClientInfo to Transport.ConnContext, allow rejection of handshake (#5122)

* pass *ClientInfo to ConnContext

This allows users to set some state related to the client's remote
address in `ConnContext`.

This also refuses connection early if the returned context is cancelled.

* review comments

* better error assertion

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
sukun
2025-05-16 15:33:05 +05:30
committed by GitHub
parent 77c401dcbb
commit eb8547ca73
5 changed files with 83 additions and 23 deletions

View File

@@ -74,7 +74,7 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context) context.Context
connContext func(context.Context, *ClientInfo) (context.Context, error)
// set as a member, so they can be set in the tests
newConn func(
@@ -240,7 +240,7 @@ func newServer(
tr *packetHandlerMap,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
connContext func(context.Context, *ClientInfo) (context.Context, error),
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@@ -638,20 +638,15 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
config := s.config
clientInfo := &ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
}
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
})
conf, err := s.config.GetConfigForClient(clientInfo)
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
s.refuseNewConn(p, hdr)
return nil
}
config = populateConfig(conf)
@@ -661,7 +656,14 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext(ctx)
var err error
ctx, err = s.connContext(ctx, clientInfo)
if err != nil {
cancel1(err)
s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err)
s.refuseNewConn(p, hdr)
return nil
}
if ctx == nil {
panic("quic: ConnContext returned nil")
}
@@ -739,6 +741,16 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
}
func (s *baseServer) handleNewConn(conn quicConn) {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed