pass a context to Transport.ConnContext (#4536)

* pass a context to Transport.ConnContext

This context is cancelled when the QUIC connection is closed, or when
the QUIC handshake fails. This allows the application to easily build
and garbage collect a map of active connections.

* correctly handle fresh contexts returned from ConnContext
This commit is contained in:
Marten Seemann
2024-06-05 12:37:18 +08:00
committed by GitHub
parent 07acaad2f7
commit 44e0147f2e
7 changed files with 114 additions and 35 deletions

View File

@@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
})
It("uses the context everywhere, on the server side", func() {
//nolint:staticcheck
serverCtx := context.WithValue(context.Background(), "foo", "bar")
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
@@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context { return serverCtx },
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()
server, err := tr.Listen(
@@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
streamContextChan <- str.Context()
str.Close()
str.Write([]byte{1, 2, 3})
}
}()
@@ -184,21 +185,63 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
_, err = c.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")
checkContext := func(c <-chan context.Context) {
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
Eventually(c).Should(Receive(&ctx))
EventuallyWithOffset(1, c).Should(Receive(&ctx))
val := ctx.Value("foo")
ExpectWithOffset(1, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(1, v).To(Equal("bar"))
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContext(connContextChan)
checkContext(tlsGetConfigForClientContextChan)
checkContext(tlsGetCertificateContextChan)
checkContext(tracerContextChan)
checkContext(streamContextChan)
checkContext(connContextChan, true)
checkContext(tracerContextChan, true)
checkContext(streamContextChan, true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContext(tlsGetConfigForClientContextChan, false)
checkContext(tlsGetCertificateContextChan, false)
})
It("correctly handles a fresh context returned from ConnContext", func() {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
}
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
if err != nil {
return
}
Eventually(conn.Context().Done).Should(BeClosed())
}()
c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")
})
It("uses the context everywhere, on the client side", func() {
@@ -227,31 +270,39 @@ var _ = Describe("Handshake tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cancel()
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
Expect(ctx.Done()).ToNot(Receive())
checkContext := func(ctx context.Context) {
checkContext := func(ctx context.Context, checkCancellationCause bool) {
val := ctx.Value("foo")
ExpectWithOffset(2, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(2, v).To(Equal("bar"))
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContextFromChan := func(c <-chan context.Context) {
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
EventuallyWithOffset(1, c).Should(Receive(&ctx))
checkContext(ctx)
checkContext(ctx, checkCancellationCause)
}
checkContext(conn.Context())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
checkContext(str.Context())
str.Close()
checkContextFromChan(tlsContextChan)
checkContextFromChan(tracerContextChan)
conn.CloseWithError(1337, "bye")
checkContext(conn.Context(), true)
checkContext(str.Context(), true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContextFromChan(tlsContextChan, false)
checkContextFromChan(tracerContextChan, false)
})
Context("using different cipher suites", func() {