introduce Transport.ConnContext, use client's context on the connection (#4507)

* introduce Transport.ConnContext, use client's context on the connection

* panic if ConnContext returns nil
This commit is contained in:
Marten Seemann
2024-05-27 12:30:19 +08:00
committed by GitHub
parent e2fbf3cdcd
commit 0d1e27d77c
15 changed files with 260 additions and 97 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -122,6 +123,137 @@ var _ = Describe("Handshake tests", func() {
Expect(err).To(MatchError("application cancelled"))
})
It("uses the context everywhere, on the server side", func() {
//nolint:staticcheck
serverCtx := context.WithValue(context.Background(), "foo", "bar")
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
connContextChan := make(chan context.Context, 1)
streamContextChan := make(chan context.Context, 1)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context { return serverCtx },
}
defer tr.Close()
server, err := tr.Listen(
&tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
tlsGetConfigForClientContextChan <- info.Context()
tlsConf := getTLSConfig()
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
tlsGetCertificateContextChan <- info.Context()
return &tlsConf.Certificates[0], nil
}
return tlsConf, nil
},
},
getQuicConfig(&quic.Config{
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
tracerContextChan <- ctx
return nil
},
}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
conn, err := server.Accept(context.Background())
if err != nil {
return
}
connContextChan <- conn.Context()
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
streamContextChan <- str.Context()
str.Close()
}
}()
c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
checkContext := func(c <-chan context.Context) {
var ctx context.Context
Eventually(c).Should(Receive(&ctx))
val := ctx.Value("foo")
ExpectWithOffset(1, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(1, v).To(Equal("bar"))
}
checkContext(connContextChan)
checkContext(tlsGetConfigForClientContextChan)
checkContext(tlsGetCertificateContextChan)
checkContext(tracerContextChan)
checkContext(streamContextChan)
})
It("uses the context everywhere, on the client side", func() {
tlsServerConf := getTLSConfig()
tlsServerConf.ClientAuth = tls.RequestClientCert
runServer(tlsServerConf)
tlsContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
tlsConf := getTLSClientConfig()
tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
tlsContextChan <- info.Context()
return &tlsServerConf.Certificates[0], nil
}
//nolint:staticcheck
ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "foo", "bar"))
conn, err := quic.DialAddr(
ctx,
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(&quic.Config{
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
tracerContextChan <- ctx
return nil
},
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cancel()
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
Expect(ctx.Done()).ToNot(Receive())
checkContext := func(ctx context.Context) {
val := ctx.Value("foo")
ExpectWithOffset(2, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(2, v).To(Equal("bar"))
}
checkContextFromChan := func(c <-chan context.Context) {
var ctx context.Context
EventuallyWithOffset(1, c).Should(Receive(&ctx))
checkContext(ctx)
}
checkContext(conn.Context())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
checkContext(str.Context())
str.Close()
checkContextFromChan(tlsContextChan)
checkContextFromChan(tracerContextChan)
})
Context("using different cipher suites", func() {
for n, id := range map[string]uint16{
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,

View File

@@ -186,10 +186,12 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
}
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
origTracer(ctx, p, connID),
)
tr := origTracer(ctx, p, connID)
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID)
if tr == nil {
return qlogger
}
return logging.NewMultiplexedConnectionTracer(qlogger, tr)
}
return conf
}