forked from quic-go/quic-go
pass a context to Transport.ConnContext (#4536)
* pass a context to Transport.ConnContext This context is cancelled when the QUIC connection is closed, or when the QUIC handshake fails. This allows the application to easily build and garbage collect a map of active connections. * correctly handle fresh contexts returned from ConnContext
This commit is contained in:
@@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the server side", func() {
|
||||
//nolint:staticcheck
|
||||
serverCtx := context.WithValue(context.Background(), "foo", "bar")
|
||||
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
|
||||
tlsGetCertificateContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
@@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func() context.Context { return serverCtx },
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context {
|
||||
//nolint:staticcheck
|
||||
return context.WithValue(ctx, "foo", "bar")
|
||||
},
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(
|
||||
@@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamContextChan <- str.Context()
|
||||
str.Close()
|
||||
str.Write([]byte{1, 2, 3})
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -184,21 +185,63 @@ var _ = Describe("Handshake tests", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
_, err = c.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext := func(c <-chan context.Context) {
|
||||
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
var ctx context.Context
|
||||
Eventually(c).Should(Receive(&ctx))
|
||||
EventuallyWithOffset(1, c).Should(Receive(&ctx))
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(1, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(1, v).To(Equal("bar"))
|
||||
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
|
||||
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
|
||||
}
|
||||
checkContext(connContextChan)
|
||||
checkContext(tlsGetConfigForClientContextChan)
|
||||
checkContext(tlsGetCertificateContextChan)
|
||||
checkContext(tracerContextChan)
|
||||
checkContext(streamContextChan)
|
||||
checkContext(connContextChan, true)
|
||||
checkContext(tracerContextChan, true)
|
||||
checkContext(streamContextChan, true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes.
|
||||
checkContext(tlsGetConfigForClientContextChan, false)
|
||||
checkContext(tlsGetCertificateContextChan, false)
|
||||
})
|
||||
|
||||
It("correctly handles a fresh context returned from ConnContext", func() {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
|
||||
}
|
||||
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
Eventually(conn.Context().Done).Should(BeClosed())
|
||||
}()
|
||||
|
||||
c, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.CloseWithError(1337, "bye")
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the client side", func() {
|
||||
@@ -227,31 +270,39 @@ var _ = Describe("Handshake tests", func() {
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
cancel()
|
||||
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
|
||||
Expect(ctx.Done()).ToNot(Receive())
|
||||
|
||||
checkContext := func(ctx context.Context) {
|
||||
checkContext := func(ctx context.Context, checkCancellationCause bool) {
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(2, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(2, v).To(Equal("bar"))
|
||||
if !checkCancellationCause {
|
||||
return
|
||||
}
|
||||
ctxErr := context.Cause(ctx)
|
||||
var appErr *quic.ApplicationError
|
||||
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
|
||||
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
|
||||
}
|
||||
|
||||
checkContextFromChan := func(c <-chan context.Context) {
|
||||
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
|
||||
var ctx context.Context
|
||||
EventuallyWithOffset(1, c).Should(Receive(&ctx))
|
||||
checkContext(ctx)
|
||||
checkContext(ctx, checkCancellationCause)
|
||||
}
|
||||
|
||||
checkContext(conn.Context())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
checkContext(str.Context())
|
||||
str.Close()
|
||||
checkContextFromChan(tlsContextChan)
|
||||
checkContextFromChan(tracerContextChan)
|
||||
conn.CloseWithError(1337, "bye")
|
||||
|
||||
checkContext(conn.Context(), true)
|
||||
checkContext(str.Context(), true)
|
||||
// crypto/tls cancels the context when the TLS handshake completes.
|
||||
checkContextFromChan(tlsContextChan, false)
|
||||
checkContextFromChan(tracerContextChan, false)
|
||||
})
|
||||
|
||||
Context("using different cipher suites", func() {
|
||||
|
||||
Reference in New Issue
Block a user