forked from quic-go/quic-go
pass ClientInfo to Transport.ConnContext, allow rejection of handshake (#5122)
* pass *ClientInfo to ConnContext This allows users to set some state related to the client's remote address in `ConnContext`. This also refuses connection early if the returned context is cancelled. * review comments * better error assertion --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -50,8 +51,8 @@ func TestConnContextOnServerSide(t *testing.T) {
|
||||
|
||||
tr := &quic.Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
|
||||
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
|
||||
return context.WithValue(ctx, "foo", "bar"), nil
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
@@ -132,11 +133,57 @@ func TestConnContextOnServerSide(t *testing.T) {
|
||||
checkContext(tlsGetCertificateContextChan, false)
|
||||
}
|
||||
|
||||
func TestConnContextRejection(t *testing.T) {
|
||||
t.Run("rejecting", func(t *testing.T) {
|
||||
testConnContextRejection(t, true)
|
||||
})
|
||||
t.Run("not rejecting", func(t *testing.T) {
|
||||
testConnContextRejection(t, false)
|
||||
})
|
||||
}
|
||||
|
||||
func testConnContextRejection(t *testing.T, reject bool) {
|
||||
tr := &quic.Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnContext: func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) {
|
||||
if reject {
|
||||
return nil, errors.New("rejecting connection")
|
||||
}
|
||||
return context.WithValue(ctx, "addr", ci.RemoteAddr), nil
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
server, err := tr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
pc := newUDPConnLocalhost(t)
|
||||
c, err := quic.Dial(ctx, pc, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
if reject {
|
||||
require.ErrorIs(t, err, &quic.TransportError{Remote: true, ErrorCode: quic.ConnectionRefused})
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer c.CloseWithError(0, "")
|
||||
|
||||
conn, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pc.LocalAddr().String(), conn.Context().Value("addr").(net.Addr).String())
|
||||
conn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
// Users are not supposed to return a fresh context from ConnContext, but we should handle it gracefully.
|
||||
func TestConnContextFreshContext(t *testing.T) {
|
||||
tr := &quic.Transport{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
|
||||
return context.Background(), nil
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
|
||||
@@ -614,8 +614,8 @@ func TestHTTPContextFromQUIC(t *testing.T) {
|
||||
conn := newUDPConnLocalhost(t)
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, "foo", "bar") //nolint:staticcheck
|
||||
ConnContext: func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
|
||||
return context.WithValue(ctx, "foo", "bar"), nil
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
40
server.go
40
server.go
@@ -74,7 +74,7 @@ type baseServer struct {
|
||||
nextZeroRTTCleanup time.Time
|
||||
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
|
||||
|
||||
connContext func(context.Context) context.Context
|
||||
connContext func(context.Context, *ClientInfo) (context.Context, error)
|
||||
|
||||
// set as a member, so they can be set in the tests
|
||||
newConn func(
|
||||
@@ -240,7 +240,7 @@ func newServer(
|
||||
tr *packetHandlerMap,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetter *statelessResetter,
|
||||
connContext func(context.Context) context.Context,
|
||||
connContext func(context.Context, *ClientInfo) (context.Context, error),
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
tracer *logging.Tracer,
|
||||
@@ -638,20 +638,15 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
}
|
||||
|
||||
config := s.config
|
||||
clientInfo := &ClientInfo{
|
||||
RemoteAddr: p.remoteAddr,
|
||||
AddrVerified: clientAddrVerified,
|
||||
}
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientInfo{
|
||||
RemoteAddr: p.remoteAddr,
|
||||
AddrVerified: clientAddrVerified,
|
||||
})
|
||||
conf, err := s.config.GetConfigForClient(clientInfo)
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
s.refuseNewConn(p, hdr)
|
||||
return nil
|
||||
}
|
||||
config = populateConfig(conf)
|
||||
@@ -661,7 +656,14 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
var cancel context.CancelCauseFunc
|
||||
ctx, cancel1 := context.WithCancelCause(context.Background())
|
||||
if s.connContext != nil {
|
||||
ctx = s.connContext(ctx)
|
||||
var err error
|
||||
ctx, err = s.connContext(ctx, clientInfo)
|
||||
if err != nil {
|
||||
cancel1(err)
|
||||
s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err)
|
||||
s.refuseNewConn(p, hdr)
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
panic("quic: ConnContext returned nil")
|
||||
}
|
||||
@@ -739,6 +741,16 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) {
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
select {
|
||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||
default:
|
||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||
p.buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
||||
if s.acceptEarlyConns {
|
||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||
|
||||
@@ -68,7 +68,7 @@ func newTestServer(t *testing.T, serverOpts *serverOpts) *testServer {
|
||||
(*packetHandlerMap)(tr),
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
&statelessResetter{},
|
||||
func(ctx context.Context) context.Context { return ctx },
|
||||
func(ctx context.Context, _ *ClientInfo) (context.Context, error) { return ctx, nil },
|
||||
&tls.Config{},
|
||||
config,
|
||||
serverOpts.tracer,
|
||||
|
||||
@@ -117,7 +117,8 @@ type Transport struct {
|
||||
// implementation of this callback (negating its return value).
|
||||
VerifySourceAddress func(net.Addr) bool
|
||||
|
||||
// ConnContext is called when the server accepts a new connection.
|
||||
// ConnContext is called when the server accepts a new connection. To reject a connection return
|
||||
// a non-nil error.
|
||||
// The context is closed when the connection is closed, or when the handshake fails for any reason.
|
||||
// The context returned from the callback is used to derive every other context used during the
|
||||
// lifetime of the connection:
|
||||
@@ -126,7 +127,7 @@ type Transport struct {
|
||||
// * the context returned from Connection.Context
|
||||
// * the context returned from SendStream.Context
|
||||
// It is not used for dialed connections.
|
||||
ConnContext func(context.Context) context.Context
|
||||
ConnContext func(context.Context, *ClientInfo) (context.Context, error)
|
||||
|
||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||
// Tracer.Close is called when the transport is closed.
|
||||
|
||||
Reference in New Issue
Block a user