diff --git a/integrationtests/self/handshake_context_test.go b/integrationtests/self/handshake_context_test.go index 71e7c6264..80da977ce 100644 --- a/integrationtests/self/handshake_context_test.go +++ b/integrationtests/self/handshake_context_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "net" "testing" "time" @@ -50,8 +51,8 @@ func TestConnContextOnServerSide(t *testing.T) { tr := &quic.Transport{ Conn: newUDPConnLocalhost(t), - ConnContext: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck + ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { + return context.WithValue(ctx, "foo", "bar"), nil }, } defer tr.Close() @@ -132,11 +133,57 @@ func TestConnContextOnServerSide(t *testing.T) { checkContext(tlsGetCertificateContextChan, false) } +func TestConnContextRejection(t *testing.T) { + t.Run("rejecting", func(t *testing.T) { + testConnContextRejection(t, true) + }) + t.Run("not rejecting", func(t *testing.T) { + testConnContextRejection(t, false) + }) +} + +func testConnContextRejection(t *testing.T, reject bool) { + tr := &quic.Transport{ + Conn: newUDPConnLocalhost(t), + ConnContext: func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) { + if reject { + return nil, errors.New("rejecting connection") + } + return context.WithValue(ctx, "addr", ci.RemoteAddr), nil + }, + } + defer tr.Close() + + server, err := tr.Listen( + getTLSConfig(), + getQuicConfig(nil), + ) + require.NoError(t, err) + defer server.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pc := newUDPConnLocalhost(t) + c, err := quic.Dial(ctx, pc, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) + if reject { + require.ErrorIs(t, err, &quic.TransportError{Remote: true, ErrorCode: quic.ConnectionRefused}) + return + } + require.NoError(t, err) + defer c.CloseWithError(0, "") + + conn, err := server.Accept(ctx) + require.NoError(t, err) + require.Equal(t, pc.LocalAddr().String(), conn.Context().Value("addr").(net.Addr).String()) + conn.CloseWithError(0, "") +} + // Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully. func TestConnContextFreshContext(t *testing.T) { tr := &quic.Transport{ - Conn: newUDPConnLocalhost(t), - ConnContext: func(ctx context.Context) context.Context { return context.Background() }, + Conn: newUDPConnLocalhost(t), + ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { + return context.Background(), nil + }, } defer tr.Close() server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 875471ef8..09869fd20 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -614,8 +614,8 @@ func TestHTTPContextFromQUIC(t *testing.T) { conn := newUDPConnLocalhost(t) tr := &quic.Transport{ Conn: conn, - ConnContext: func(ctx context.Context) context.Context { - return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck + ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) { + return context.WithValue(ctx, "foo", "bar"), nil }, } defer tr.Close() diff --git a/server.go b/server.go index e328521df..10b67a7ef 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ type baseServer struct { nextZeroRTTCleanup time.Time zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true - connContext func(context.Context) context.Context + connContext func(context.Context, *ClientInfo) (context.Context, error) // set as a member, so they can be set in the tests newConn func( @@ -240,7 +240,7 @@ func newServer( tr *packetHandlerMap, connIDGenerator ConnectionIDGenerator, statelessResetter *statelessResetter, - connContext func(context.Context) context.Context, + connContext func(context.Context, *ClientInfo) (context.Context, error), tlsConf *tls.Config, config *Config, tracer *logging.Tracer, @@ -638,20 +638,15 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } config := s.config + clientInfo := &ClientInfo{ + RemoteAddr: p.remoteAddr, + AddrVerified: clientAddrVerified, + } if s.config.GetConfigForClient != nil { - conf, err := s.config.GetConfigForClient(&ClientInfo{ - RemoteAddr: p.remoteAddr, - AddrVerified: clientAddrVerified, - }) + conf, err := s.config.GetConfigForClient(clientInfo) if err != nil { s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") - delete(s.zeroRTTQueues, hdr.DestConnectionID) - select { - case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: - default: - // drop packet if we can't send out the CONNECTION_REFUSED fast enough - p.buffer.Release() - } + s.refuseNewConn(p, hdr) return nil } config = populateConfig(conf) @@ -661,7 +656,14 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error var cancel context.CancelCauseFunc ctx, cancel1 := context.WithCancelCause(context.Background()) if s.connContext != nil { - ctx = s.connContext(ctx) + var err error + ctx, err = s.connContext(ctx, clientInfo) + if err != nil { + cancel1(err) + s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err) + s.refuseNewConn(p, hdr) + return nil + } if ctx == nil { panic("quic: ConnContext returned nil") } @@ -739,6 +741,16 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error return nil } +func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) { + delete(s.zeroRTTQueues, hdr.DestConnectionID) + select { + case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: + default: + // drop packet if we can't send out the CONNECTION_REFUSED fast enough + p.buffer.Release() + } +} + func (s *baseServer) handleNewConn(conn quicConn) { if s.acceptEarlyConns { // wait until the early connection is ready, the handshake fails, or the server is closed diff --git a/server_test.go b/server_test.go index 38cdfe891..b6dad516c 100644 --- a/server_test.go +++ b/server_test.go @@ -68,7 +68,7 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer { (*packetHandlerMap)(tr), &protocol.DefaultConnectionIDGenerator{}, &statelessResetter{}, - func(ctx context.Context) context.Context { return ctx }, + func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil }, &tls.Config{}, config, serverOpts.tracer, diff --git a/transport.go b/transport.go index 65a2d8034..525adb3f1 100644 --- a/transport.go +++ b/transport.go @@ -117,7 +117,8 @@ type Transport struct { // implementation of this callback (negating its return value). VerifySourceAddress func(net.Addr) bool - // ConnContext is called when the server accepts a new connection. + // ConnContext is called when the server accepts a new connection. To reject a connection return + // a non-nil error. // The context is closed when the connection is closed, or when the handshake fails for any reason. // The context returned from the callback is used to derive every other context used during the // lifetime of the connection: @@ -126,7 +127,7 @@ type Transport struct { // * the context returned from Connection.Context // * the context returned from SendStream.Context // It is not used for dialed connections. - ConnContext func(context.Context) context.Context + ConnContext func(context.Context, *ClientInfo) (context.Context, error) // A Tracer traces events that don't belong to a single QUIC connection. // Tracer.Close is called when the transport is closed.