pass ClientInfo to Transport.ConnContext, allow rejection of handshake (#5122)

* pass *ClientInfo to ConnContext

This allows users to set some state related to the client's remote
address in `ConnContext`.

This also refuses connection early if the returned context is cancelled.

* review comments

* better error assertion

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
sukun
2025-05-16 15:33:05 +05:30
committed by GitHub
parent 77c401dcbb
commit eb8547ca73
5 changed files with 83 additions and 23 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
@@ -50,8 +51,8 @@ func TestConnContextOnServerSide(t *testing.T) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.WithValue(ctx, "foo", "bar"), nil
},
}
defer tr.Close()
@@ -132,11 +133,57 @@ func TestConnContextOnServerSide(t *testing.T) {
checkContext(tlsGetCertificateContextChan, false)
}
func TestConnContextRejection(t *testing.T) {
t.Run("rejecting", func(t *testing.T) {
testConnContextRejection(t, true)
})
t.Run("not rejecting", func(t *testing.T) {
testConnContextRejection(t, false)
})
}
func testConnContextRejection(t *testing.T, reject bool) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) {
if reject {
return nil, errors.New("rejecting connection")
}
return context.WithValue(ctx, "addr", ci.RemoteAddr), nil
},
}
defer tr.Close()
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer server.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pc := newUDPConnLocalhost(t)
c, err := quic.Dial(ctx, pc, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
if reject {
require.ErrorIs(t, err, &quic.TransportError{Remote: true, ErrorCode: quic.ConnectionRefused})
return
}
require.NoError(t, err)
defer c.CloseWithError(0, "")
conn, err := server.Accept(ctx)
require.NoError(t, err)
require.Equal(t, pc.LocalAddr().String(), conn.Context().Value("addr").(net.Addr).String())
conn.CloseWithError(0, "")
}
// Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully.
func TestConnContextFreshContext(t *testing.T) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.Background(), nil
},
}
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))

View File

@@ -614,8 +614,8 @@ func TestHTTPContextFromQUIC(t *testing.T) {
conn := newUDPConnLocalhost(t)
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.WithValue(ctx, "foo", "bar"), nil
},
}
defer tr.Close()

View File

@@ -74,7 +74,7 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context) context.Context
connContext func(context.Context, *ClientInfo) (context.Context, error)
// set as a member, so they can be set in the tests
newConn func(
@@ -240,7 +240,7 @@ func newServer(
tr *packetHandlerMap,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
connContext func(context.Context, *ClientInfo) (context.Context, error),
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@@ -638,20 +638,15 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
config := s.config
clientInfo := &ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
}
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
})
conf, err := s.config.GetConfigForClient(clientInfo)
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
s.refuseNewConn(p, hdr)
return nil
}
config = populateConfig(conf)
@@ -661,7 +656,14 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext(ctx)
var err error
ctx, err = s.connContext(ctx, clientInfo)
if err != nil {
cancel1(err)
s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err)
s.refuseNewConn(p, hdr)
return nil
}
if ctx == nil {
panic("quic: ConnContext returned nil")
}
@@ -739,6 +741,16 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
}
func (s *baseServer) handleNewConn(conn quicConn) {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed

View File

@@ -68,7 +68,7 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
(*packetHandlerMap)(tr),
&protocol.DefaultConnectionIDGenerator{},
&statelessResetter{},
func(ctx context.Context) context.Context { return ctx },
func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil },
&tls.Config{},
config,
serverOpts.tracer,

View File

@@ -117,7 +117,8 @@ type Transport struct {
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection.
// ConnContext is called when the server accepts a new connection. To reject a connection return
// a non-nil error.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
@@ -126,7 +127,7 @@ type Transport struct {
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func(context.Context) context.Context
ConnContext func(context.Context, *ClientInfo) (context.Context, error)
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.