introduce Transport.ConnContext, use client's context on the connection (#4507)

* introduce Transport.ConnContext, use client's context on the connection

* panic if ConnContext returns nil
This commit is contained in:
Marten Seemann
2024-05-27 12:30:19 +08:00
committed by GitHub
parent e2fbf3cdcd
commit 0d1e27d77c
15 changed files with 260 additions and 97 deletions

View File

@@ -76,8 +76,11 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func() context.Context
// set as a member, so they can be set in the tests
newConn func(
context.Context,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
@@ -92,7 +95,6 @@ type baseServer struct {
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
*logging.ConnectionTracer,
ConnectionTracingID,
utils.Logger,
protocol.Version,
) quicConn
@@ -231,6 +233,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func() context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@@ -243,6 +246,7 @@ func newServer(
) *baseServer {
s := &baseServer{
conn: conn,
connContext: connContext,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
@@ -631,7 +635,16 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var conn quicConn
tracingID := nextConnTracingID()
var ctx context.Context
if s.connContext != nil {
ctx = s.connContext()
if ctx == nil {
panic("quic: ConnContext returned nil")
}
} else {
ctx = context.Background()
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@@ -639,7 +652,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
tracer = config.Tracer(ctx, protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
@@ -647,6 +660,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
@@ -661,7 +675,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.tokenGenerator,
clientAddrVerified,
tracer,
tracingID,
s.logger,
hdr.Version,
)