pass a context to Transport.ConnContext (#4536)

* pass a context to Transport.ConnContext

This context is cancelled when the QUIC connection is closed, or when
the QUIC handshake fails. This allows the application to easily build
and garbage collect a map of active connections.

* correctly handle fresh contexts returned from ConnContext
This commit is contained in:
Marten Seemann
2024-06-05 12:37:18 +08:00
committed by GitHub
parent 07acaad2f7
commit 44e0147f2e
7 changed files with 114 additions and 35 deletions

View File

@@ -222,6 +222,7 @@ var (
var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
@@ -240,6 +241,8 @@ var newConnection = func(
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@@ -273,7 +276,6 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
@@ -499,9 +501,7 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()
s.timer = *newTimer()

View File

@@ -115,8 +115,10 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().SentTransportParameters(gomock.Any())
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
ctx, cancel := context.WithCancelCause(context.Background())
conn = newConnection(
context.Background(),
ctx,
cancel,
mconn,
connRunner,
protocol.ConnectionID{},

View File

@@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
})
It("uses the context everywhere, on the server side", func() {
//nolint:staticcheck
serverCtx := context.WithValue(context.Background(), "foo", "bar")
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
@@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context { return serverCtx },
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()
server, err := tr.Listen(
@@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
streamContextChan <- str.Context()
str.Close()
str.Write([]byte{1, 2, 3})
}
}()
@@ -184,21 +185,63 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
_, err = c.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")
checkContext := func(c <-chan context.Context) {
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
Eventually(c).Should(Receive(&ctx))
EventuallyWithOffset(1, c).Should(Receive(&ctx))
val := ctx.Value("foo")
ExpectWithOffset(1, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(1, v).To(Equal("bar"))
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContext(connContextChan)
checkContext(tlsGetConfigForClientContextChan)
checkContext(tlsGetCertificateContextChan)
checkContext(tracerContextChan)
checkContext(streamContextChan)
checkContext(connContextChan, true)
checkContext(tracerContextChan, true)
checkContext(streamContextChan, true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContext(tlsGetConfigForClientContextChan, false)
checkContext(tlsGetCertificateContextChan, false)
})
It("correctly handles a fresh context returned from ConnContext", func() {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
}
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
if err != nil {
return
}
Eventually(conn.Context().Done).Should(BeClosed())
}()
c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")
})
It("uses the context everywhere, on the client side", func() {
@@ -227,31 +270,39 @@ var _ = Describe("Handshake tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cancel()
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
Expect(ctx.Done()).ToNot(Receive())
checkContext := func(ctx context.Context) {
checkContext := func(ctx context.Context, checkCancellationCause bool) {
val := ctx.Value("foo")
ExpectWithOffset(2, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(2, v).To(Equal("bar"))
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContextFromChan := func(c <-chan context.Context) {
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
EventuallyWithOffset(1, c).Should(Receive(&ctx))
checkContext(ctx)
checkContext(ctx, checkCancellationCause)
}
checkContext(conn.Context())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
checkContext(str.Context())
str.Close()
checkContextFromChan(tlsContextChan)
checkContextFromChan(tracerContextChan)
conn.CloseWithError(1337, "bye")
checkContext(conn.Context(), true)
checkContext(str.Context(), true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContextFromChan(tlsContextChan, false)
checkContextFromChan(tracerContextChan, false)
})
Context("using different cipher suites", func() {

View File

@@ -600,9 +600,9 @@ var _ = Describe("HTTP tests", func() {
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context {
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(context.Background(), "foo", "bar")
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()

View File

@@ -76,11 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func() context.Context
connContext func(context.Context) context.Context
// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
@@ -233,7 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func() context.Context,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@@ -635,14 +636,24 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var conn quicConn
var ctx context.Context
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext()
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
ctx = context.Background()
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
@@ -661,6 +672,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,

View File

@@ -288,6 +288,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
@@ -490,6 +491,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
@@ -558,6 +560,7 @@ var _ = Describe("Server", func() {
var counter atomic.Uint32
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -613,6 +616,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -662,6 +666,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -723,6 +728,7 @@ var _ = Describe("Server", func() {
It("decodes the token from the token field", func() {
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
@@ -956,6 +962,7 @@ var _ = Describe("Server", func() {
destroyed := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
@@ -1023,6 +1030,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
@@ -1093,6 +1101,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -1163,6 +1172,7 @@ var _ = Describe("Server", func() {
ready := make(chan struct{})
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -1204,6 +1214,7 @@ var _ = Describe("Server", func() {
wg.Add(protocol.MaxAcceptQueueSize)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -1263,6 +1274,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
@@ -1385,6 +1397,7 @@ var _ = Describe("Server", func() {
called := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,

View File

@@ -90,6 +90,7 @@ type Transport struct {
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
@@ -97,7 +98,7 @@ type Transport struct {
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func() context.Context
ConnContext func(context.Context) context.Context
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.