forked from quic-go/quic-go
pass a context to Transport.ConnContext (#4536)
* pass a context to Transport.ConnContext This context is cancelled when the QUIC connection is closed, or when the QUIC handshake fails. This allows the application to easily build and garbage collect a map of active connections. * correctly handle fresh contexts returned from ConnContext
This commit is contained in:
@@ -222,6 +222,7 @@ var (
|
||||
|
||||
var newConnection = func(
|
||||
ctx context.Context,
|
||||
ctxCancel context.CancelCauseFunc,
|
||||
conn sendConn,
|
||||
runner connRunner,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
@@ -240,6 +241,8 @@ var newConnection = func(
|
||||
v protocol.Version,
|
||||
) quicConn {
|
||||
s := &connection{
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
conn: conn,
|
||||
config: conf,
|
||||
handshakeDestConnID: destConnID,
|
||||
@@ -273,7 +276,6 @@ var newConnection = func(
|
||||
s.queueControlFrame,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
|
||||
s.preSetup()
|
||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||
0,
|
||||
@@ -499,9 +501,7 @@ func (s *connection) preSetup() {
|
||||
// run the connection main loop
|
||||
func (s *connection) run() error {
|
||||
var closeErr closeError
|
||||
defer func() {
|
||||
s.ctxCancel(closeErr.err)
|
||||
}()
|
||||
defer func() { s.ctxCancel(closeErr.err) }()
|
||||
|
||||
s.timer = *newTimer()
|
||||
|
||||
|
||||
@@ -115,8 +115,10 @@ var _ = Describe("Connection", func() {
|
||||
tracer.EXPECT().SentTransportParameters(gomock.Any())
|
||||
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
conn = newConnection(
|
||||
context.Background(),
|
||||
ctx,
|
||||
cancel,
|
||||
mconn,
|
||||
connRunner,
|
||||
protocol.ConnectionID{},
|
||||
|
||||
@@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the server side", func() {
|
||||
//nolint:staticcheck
|
||||
serverCtx := context.WithValue(context.Background(), "foo", "bar")
|
||||
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
|
||||
tlsGetCertificateContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
@@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func() context.Context { return serverCtx },
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
//nolint:staticcheck
|
||||
return context.WithValue(ctx, "foo", "bar")
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(
|
||||
@@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamContextChan <- str.Context()
|
||||
str.Close()
|
||||
str.Write([]byte{1, 2, 3})
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -184,21 +185,63 @@ var _ = Describe("Handshake tests", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
_, err = c.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext := func(c <-chan context.Context) {
|
||||
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
var ctx context.Context
|
||||
Eventually(c).Should(Receive(&ctx))
|
||||
EventuallyWithOffset(1, c).Should(Receive(&ctx))
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(1, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(1, v).To(Equal("bar"))
|
||||
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
|
||||
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
|
||||
}
|
||||
checkContext(connContextChan)
|
||||
checkContext(tlsGetConfigForClientContextChan)
|
||||
checkContext(tlsGetCertificateContextChan)
|
||||
checkContext(tracerContextChan)
|
||||
checkContext(streamContextChan)
|
||||
checkContext(connContextChan, true)
|
||||
checkContext(tracerContextChan, true)
|
||||
checkContext(streamContextChan, true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes.
|
||||
checkContext(tlsGetConfigForClientContextChan, false)
|
||||
checkContext(tlsGetCertificateContextChan, false)
|
||||
})
|
||||
|
||||
It("correctly handles a fresh context returned from ConnContext", func() {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
|
||||
}
|
||||
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
Eventually(conn.Context().Done).Should(BeClosed())
|
||||
}()
|
||||
|
||||
c, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.CloseWithError(1337, "bye")
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the client side", func() {
|
||||
@@ -227,31 +270,39 @@ var _ = Describe("Handshake tests", func() {
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
cancel()
|
||||
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
|
||||
Expect(ctx.Done()).ToNot(Receive())
|
||||
|
||||
checkContext := func(ctx context.Context) {
|
||||
checkContext := func(ctx context.Context, checkCancellationCause bool) {
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(2, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(2, v).To(Equal("bar"))
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
|
||||
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
|
||||
}
|
||||
|
||||
checkContextFromChan := func(c <-chan context.Context) {
|
||||
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
var ctx context.Context
|
||||
EventuallyWithOffset(1, c).Should(Receive(&ctx))
|
||||
checkContext(ctx)
|
||||
checkContext(ctx, checkCancellationCause)
|
||||
}
|
||||
|
||||
checkContext(conn.Context())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
checkContext(str.Context())
|
||||
str.Close()
|
||||
checkContextFromChan(tlsContextChan)
|
||||
checkContextFromChan(tracerContextChan)
|
||||
conn.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext(conn.Context(), true)
|
||||
checkContext(str.Context(), true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes.
|
||||
checkContextFromChan(tlsContextChan, false)
|
||||
checkContextFromChan(tracerContextChan, false)
|
||||
})
|
||||
|
||||
Context("using different cipher suites", func() {
|
||||
|
||||
@@ -600,9 +600,9 @@ var _ = Describe("HTTP tests", func() {
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func() context.Context {
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
//nolint:staticcheck
|
||||
return context.WithValue(context.Background(), "foo", "bar")
|
||||
return context.WithValue(ctx, "foo", "bar")
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
|
||||
22
server.go
22
server.go
@@ -76,11 +76,12 @@ type baseServer struct {
|
||||
nextZeroRTTCleanup time.Time
|
||||
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
|
||||
|
||||
connContext func() context.Context
|
||||
connContext func(context.Context) context.Context
|
||||
|
||||
// set as a member, so they can be set in the tests
|
||||
newConn func(
|
||||
context.Context,
|
||||
context.CancelCauseFunc,
|
||||
sendConn,
|
||||
connRunner,
|
||||
protocol.ConnectionID, /* original dest connection ID */
|
||||
@@ -233,7 +234,7 @@ func newServer(
|
||||
conn rawConn,
|
||||
connHandler packetHandlerManager,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
connContext func() context.Context,
|
||||
connContext func(context.Context) context.Context,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
tracer *logging.Tracer,
|
||||
@@ -635,14 +636,24 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
}
|
||||
|
||||
var conn quicConn
|
||||
var ctx context.Context
|
||||
var cancel context.CancelCauseFunc
|
||||
ctx, cancel1 := context.WithCancelCause(context.Background())
|
||||
if s.connContext != nil {
|
||||
ctx = s.connContext()
|
||||
ctx = s.connContext(ctx)
|
||||
if ctx == nil {
|
||||
panic("quic: ConnContext returned nil")
|
||||
}
|
||||
// There's no guarantee that the application returns a context
|
||||
// that's derived from the context we passed into ConnContext.
|
||||
// We need to make sure that both contexts are cancelled.
|
||||
var cancel2 context.CancelCauseFunc
|
||||
ctx, cancel2 = context.WithCancelCause(ctx)
|
||||
cancel = func(cause error) {
|
||||
cancel1(cause)
|
||||
cancel2(cause)
|
||||
}
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
cancel = cancel1
|
||||
}
|
||||
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
|
||||
var tracer *logging.ConnectionTracer
|
||||
@@ -661,6 +672,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
conn = s.newConn(
|
||||
ctx,
|
||||
cancel,
|
||||
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
|
||||
s.connHandler,
|
||||
origDestConnID,
|
||||
|
||||
@@ -288,6 +288,7 @@ var _ = Describe("Server", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
@@ -490,6 +491,7 @@ var _ = Describe("Server", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
@@ -558,6 +560,7 @@ var _ = Describe("Server", func() {
|
||||
var counter atomic.Uint32
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -613,6 +616,7 @@ var _ = Describe("Server", func() {
|
||||
done := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -662,6 +666,7 @@ var _ = Describe("Server", func() {
|
||||
done := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -723,6 +728,7 @@ var _ = Describe("Server", func() {
|
||||
It("decodes the token from the token field", func() {
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -956,6 +962,7 @@ var _ = Describe("Server", func() {
|
||||
destroyed := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1023,6 +1030,7 @@ var _ = Describe("Server", func() {
|
||||
handshakeChan := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1093,6 +1101,7 @@ var _ = Describe("Server", func() {
|
||||
handshakeChan := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1163,6 +1172,7 @@ var _ = Describe("Server", func() {
|
||||
ready := make(chan struct{})
|
||||
serv.baseServer.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1204,6 +1214,7 @@ var _ = Describe("Server", func() {
|
||||
wg.Add(protocol.MaxAcceptQueueSize)
|
||||
serv.baseServer.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1263,6 +1274,7 @@ var _ = Describe("Server", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.baseServer.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
@@ -1385,6 +1397,7 @@ var _ = Describe("Server", func() {
|
||||
called := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ context.Context,
|
||||
_ context.CancelCauseFunc,
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
|
||||
@@ -90,6 +90,7 @@ type Transport struct {
|
||||
VerifySourceAddress func(net.Addr) bool
|
||||
|
||||
// ConnContext is called when the server accepts a new connection.
|
||||
// The context is closed when the connection is closed, or when the handshake fails for any reason.
|
||||
// The context returned from the callback is used to derive every other context used during the
|
||||
// lifetime of the connection:
|
||||
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
|
||||
@@ -97,7 +98,7 @@ type Transport struct {
|
||||
// * the context returned from Connection.Context
|
||||
// * the context returned from SendStream.Context
|
||||
// It is not used for dialed connections.
|
||||
ConnContext func() context.Context
|
||||
ConnContext func(context.Context) context.Context
|
||||
|
||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||
// Tracer.Close is called when the transport is closed.
|
||||
|
||||
Reference in New Issue
Block a user