forked from quic-go/quic-go
introduce Transport.ConnContext, use client's context on the connection (#4507)
* introduce Transport.ConnContext, use client's context on the connection * panic if ConnContext returns nil
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -122,6 +123,137 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err).To(MatchError("application cancelled"))
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the server side", func() {
|
||||
//nolint:staticcheck
|
||||
serverCtx := context.WithValue(context.Background(), "foo", "bar")
|
||||
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
|
||||
tlsGetCertificateContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
connContextChan := make(chan context.Context, 1)
|
||||
streamContextChan := make(chan context.Context, 1)
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnContext: func() context.Context { return serverCtx },
|
||||
}
|
||||
defer tr.Close()
|
||||
server, err := tr.Listen(
|
||||
&tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
tlsGetConfigForClientContextChan <- info.Context()
|
||||
tlsConf := getTLSConfig()
|
||||
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
tlsGetCertificateContextChan <- info.Context()
|
||||
return &tlsConf.Certificates[0], nil
|
||||
}
|
||||
return tlsConf, nil
|
||||
},
|
||||
},
|
||||
getQuicConfig(&quic.Config{
|
||||
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tracerContextChan <- ctx
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(acceptStopped)
|
||||
for {
|
||||
conn, err := server.Accept(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
connContextChan <- conn.Context()
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamContextChan <- str.Context()
|
||||
str.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer c.CloseWithError(0, "")
|
||||
|
||||
checkContext := func(c <-chan context.Context) {
|
||||
var ctx context.Context
|
||||
Eventually(c).Should(Receive(&ctx))
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(1, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(1, v).To(Equal("bar"))
|
||||
}
|
||||
checkContext(connContextChan)
|
||||
checkContext(tlsGetConfigForClientContextChan)
|
||||
checkContext(tlsGetCertificateContextChan)
|
||||
checkContext(tracerContextChan)
|
||||
checkContext(streamContextChan)
|
||||
})
|
||||
|
||||
It("uses the context everywhere, on the client side", func() {
|
||||
tlsServerConf := getTLSConfig()
|
||||
tlsServerConf.ClientAuth = tls.RequestClientCert
|
||||
runServer(tlsServerConf)
|
||||
|
||||
tlsContextChan := make(chan context.Context, 1)
|
||||
tracerContextChan := make(chan context.Context, 1)
|
||||
tlsConf := getTLSClientConfig()
|
||||
tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
tlsContextChan <- info.Context()
|
||||
return &tlsServerConf.Certificates[0], nil
|
||||
}
|
||||
//nolint:staticcheck
|
||||
ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "foo", "bar"))
|
||||
conn, err := quic.DialAddr(
|
||||
ctx,
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Tracer: func(ctx context.Context, _ logging.Perspective, _ quic.ConnectionID) *logging.ConnectionTracer {
|
||||
tracerContextChan <- ctx
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
cancel()
|
||||
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
|
||||
Expect(ctx.Done()).ToNot(Receive())
|
||||
|
||||
checkContext := func(ctx context.Context) {
|
||||
val := ctx.Value("foo")
|
||||
ExpectWithOffset(2, val).ToNot(BeNil())
|
||||
v := val.(string)
|
||||
ExpectWithOffset(2, v).To(Equal("bar"))
|
||||
}
|
||||
|
||||
checkContextFromChan := func(c <-chan context.Context) {
|
||||
var ctx context.Context
|
||||
EventuallyWithOffset(1, c).Should(Receive(&ctx))
|
||||
checkContext(ctx)
|
||||
}
|
||||
|
||||
checkContext(conn.Context())
|
||||
str, err := conn.OpenUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
checkContext(str.Context())
|
||||
str.Close()
|
||||
checkContextFromChan(tlsContextChan)
|
||||
checkContextFromChan(tracerContextChan)
|
||||
})
|
||||
|
||||
Context("using different cipher suites", func() {
|
||||
for n, id := range map[string]uint16{
|
||||
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||
|
||||
@@ -186,10 +186,12 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
||||
}
|
||||
origTracer := conf.Tracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
)
|
||||
tr := origTracer(ctx, p, connID)
|
||||
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID)
|
||||
if tr == nil {
|
||||
return qlogger
|
||||
}
|
||||
return logging.NewMultiplexedConnectionTracer(qlogger, tr)
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user