pass ClientInfo to Transport.ConnContext, allow rejection of handshake (#5122)

* pass *ClientInfo to ConnContext

This allows users to set some state related to the client's remote
address in `ConnContext`.

This also refuses connection early if the returned context is cancelled.

* review comments

* better error assertion

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
sukun
2025-05-16 15:33:05 +05:30
committed by GitHub
parent 77c401dcbb
commit eb8547ca73
5 changed files with 83 additions and 23 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
@@ -50,8 +51,8 @@ func TestConnContextOnServerSide(t *testing.T) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.WithValue(ctx, "foo", "bar"), nil
},
}
defer tr.Close()
@@ -132,11 +133,57 @@ func TestConnContextOnServerSide(t *testing.T) {
checkContext(tlsGetCertificateContextChan, false)
}
func TestConnContextRejection(t *testing.T) {
t.Run("rejecting", func(t *testing.T) {
testConnContextRejection(t, true)
})
t.Run("not rejecting", func(t *testing.T) {
testConnContextRejection(t, false)
})
}
func testConnContextRejection(t *testing.T, reject bool) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) {
if reject {
return nil, errors.New("rejecting connection")
}
return context.WithValue(ctx, "addr", ci.RemoteAddr), nil
},
}
defer tr.Close()
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
require.NoError(t, err)
defer server.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pc := newUDPConnLocalhost(t)
c, err := quic.Dial(ctx, pc, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
if reject {
require.ErrorIs(t, err, &quic.TransportError{Remote: true, ErrorCode: quic.ConnectionRefused})
return
}
require.NoError(t, err)
defer c.CloseWithError(0, "")
conn, err := server.Accept(ctx)
require.NoError(t, err)
require.Equal(t, pc.LocalAddr().String(), conn.Context().Value("addr").(net.Addr).String())
conn.CloseWithError(0, "")
}
// Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully.
func TestConnContextFreshContext(t *testing.T) {
tr := &quic.Transport{
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
Conn: newUDPConnLocalhost(t),
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.Background(), nil
},
}
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))

View File

@@ -614,8 +614,8 @@ func TestHTTPContextFromQUIC(t *testing.T) {
conn := newUDPConnLocalhost(t)
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
return context.WithValue(ctx, "foo", "bar"), nil
},
}
defer tr.Close()