forked from quic-go/quic-go
http3: add ConnContext to the server (#4230)
* Add ConnContext to http3.Server ConnContext can be used to modify the context used by a new http Request. * Make linter happy * Add nil check and integration test * Add the ServerContextKey check to the ConnContext func * Update integrationtests/self/http_test.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> * Update http3/server.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
@@ -211,6 +211,11 @@ type Server struct {
|
|||||||
// In that case, the stream type will not be set.
|
// In that case, the stream type will not be set.
|
||||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||||
|
|
||||||
|
// ConnContext optionally specifies a function that modifies
|
||||||
|
// the context used for a new connection c. The provided ctx
|
||||||
|
// has a ServerContextKey value.
|
||||||
|
ConnContext func(ctx context.Context, c quic.Connection) context.Context
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
listeners map[*QUICEarlyListener]listenerInfo
|
listeners map[*QUICEarlyListener]listenerInfo
|
||||||
|
|
||||||
@@ -610,6 +615,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
|
|||||||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||||
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
|
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
|
||||||
|
if s.ConnContext != nil {
|
||||||
|
ctx = s.ConnContext(ctx, conn)
|
||||||
|
if ctx == nil {
|
||||||
|
panic("http3: ConnContext returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
r := newResponseWriter(str, conn, s.logger)
|
r := newResponseWriter(str, conn, s.logger)
|
||||||
if req.Method == http.MethodHead {
|
if req.Method == http.MethodHead {
|
||||||
|
|||||||
@@ -67,11 +67,15 @@ var _ = Describe("Server", func() {
|
|||||||
s *Server
|
s *Server
|
||||||
origQuicListenAddr = quicListenAddr
|
origQuicListenAddr = quicListenAddr
|
||||||
)
|
)
|
||||||
|
type testConnContextKey string
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
s = &Server{
|
s = &Server{
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
TLSConfig: testdata.GetTLSConfig(),
|
||||||
logger: utils.DefaultLogger,
|
logger: utils.DefaultLogger,
|
||||||
|
ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
|
||||||
|
return context.WithValue(ctx, testConnContextKey("test"), c)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
origQuicListenAddr = quicListenAddr
|
origQuicListenAddr = quicListenAddr
|
||||||
})
|
})
|
||||||
@@ -163,6 +167,7 @@ var _ = Describe("Server", func() {
|
|||||||
Expect(req.Host).To(Equal("www.example.com"))
|
Expect(req.Host).To(Equal("www.example.com"))
|
||||||
Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
|
Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
|
||||||
Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
|
Expect(req.Context().Value(ServerContextKey)).To(Equal(s))
|
||||||
|
Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns 200 with an empty handler", func() {
|
It("returns 200 with an empty handler", func() {
|
||||||
|
|||||||
@@ -528,4 +528,35 @@ var _ = Describe("HTTP tests", func() {
|
|||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("sets conn context", func() {
|
||||||
|
type ctxKey int
|
||||||
|
server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context {
|
||||||
|
serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(serv).To(Equal(server))
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, ctxKey(0), "Hello")
|
||||||
|
ctx = context.WithValue(ctx, ctxKey(1), c)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
v, ok := r.Context().Value(ctxKey(0)).(string)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(v).To(Equal("Hello"))
|
||||||
|
|
||||||
|
c, ok := r.Context().Value(ctxKey(1)).(quic.Connection)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(c).ToNot(BeNil())
|
||||||
|
|
||||||
|
serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
Expect(serv).To(Equal(server))
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user