From 3ff50295ce73c4b6efcf58330d53798e94528097 Mon Sep 17 00:00:00 2001 From: Robin Thellend Date: Thu, 4 Jan 2024 19:13:53 -0800 Subject: [PATCH] http3: add ConnContext to the server (#4230) * Add ConnContext to http3.Server ConnContext can be used to modify the context used by a new http Request. * Make linter happy * Add nil check and integration test * Add the ServerContextKey check to the ConnContext func * Update integrationtests/self/http_test.go Co-authored-by: Marten Seemann * Update http3/server.go Co-authored-by: Marten Seemann --------- Co-authored-by: Marten Seemann --- http3/server.go | 11 +++++++++++ http3/server_test.go | 5 +++++ integrationtests/self/http_test.go | 31 ++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/http3/server.go b/http3/server.go index ca586f0de..769942c16 100644 --- a/http3/server.go +++ b/http3/server.go @@ -211,6 +211,11 @@ type Server struct { // In that case, the stream type will not be set. UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + // ConnContext optionally specifies a function that modifies + // the context used for a new connection c. The provided ctx + // has a ServerContextKey value. + ConnContext func(ctx context.Context, c quic.Connection) context.Context + mutex sync.RWMutex listeners map[*QUICEarlyListener]listenerInfo @@ -610,6 +615,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q ctx = context.WithValue(ctx, ServerContextKey, s) ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr()) + if s.ConnContext != nil { + ctx = s.ConnContext(ctx, conn) + if ctx == nil { + panic("http3: ConnContext returned nil") + } + } req = req.WithContext(ctx) r := newResponseWriter(str, conn, s.logger) if req.Method == http.MethodHead { diff --git a/http3/server_test.go b/http3/server_test.go index 8da806111..ab702c421 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -67,11 +67,15 @@ var _ = Describe("Server", func() { s *Server origQuicListenAddr = quicListenAddr ) + type testConnContextKey string BeforeEach(func() { s = &Server{ TLSConfig: testdata.GetTLSConfig(), logger: utils.DefaultLogger, + ConnContext: func(ctx context.Context, c quic.Connection) context.Context { + return context.WithValue(ctx, testConnContextKey("test"), c) + }, } origQuicListenAddr = quicListenAddr }) @@ -163,6 +167,7 @@ var _ = Describe("Server", func() { Expect(req.Host).To(Equal("www.example.com")) Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) + Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil)) }) It("returns 200 with an empty handler", func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 96e72dc7c..cf9f683e4 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -528,4 +528,35 @@ var _ = Describe("HTTP tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) }) + + It("sets conn context", func() { + type ctxKey int + server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context { + serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server) + Expect(ok).To(BeTrue()) + Expect(serv).To(Equal(server)) + + ctx = context.WithValue(ctx, ctxKey(0), "Hello") + ctx = context.WithValue(ctx, ctxKey(1), c) + return ctx + } + mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + v, ok := r.Context().Value(ctxKey(0)).(string) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("Hello")) + + c, ok := r.Context().Value(ctxKey(1)).(quic.Connection) + Expect(ok).To(BeTrue()) + Expect(c).ToNot(BeNil()) + + serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server) + Expect(ok).To(BeTrue()) + Expect(serv).To(Equal(server)) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + }) })