diff --git a/http3/body.go b/http3/body.go index dc168e9e..6beb2037 100644 --- a/http3/body.go +++ b/http3/body.go @@ -37,27 +37,28 @@ type Hijacker interface { StreamCreator() StreamCreator } -// The body of a http.Request or http.Response. +// Settingser allows the server to retrieve the client's SETTINGS. +// The http.Request.Body implements this interface. +type Settingser interface { + // Settings returns the client's HTTP settings. + // It blocks until the SETTINGS frame has been received. + // Note that it is not guaranteed that this happens during the lifetime of the request. + Settings(context.Context) (*Settings, error) +} + +// The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response). type body struct { str quic.Stream wasHijacked bool // set when HTTPStream is called } -var ( - _ io.ReadCloser = &body{} - _ HTTPStreamer = &body{} -) - -func newRequestBody(str Stream) *body { - return &body{str: str} -} - func (r *body) HTTPStream() Stream { r.wasHijacked = true return r.str } +func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } func (r *body) wasStreamHijacked() bool { return r.wasHijacked } @@ -72,6 +73,39 @@ func (r *body) Close() error { return nil } +type requestBody struct { + body + connCtx context.Context + rcvdSettings <-chan struct{} + getSettings func() *Settings +} + +var ( + _ io.ReadCloser = &requestBody{} + _ HTTPStreamer = &requestBody{} + _ Settingser = &requestBody{} +) + +func newRequestBody(str Stream, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody { + return &requestBody{ + body: body{str: str}, + connCtx: connCtx, + rcvdSettings: rcvdSettings, + getSettings: getSettings, + } +} + +func (r *requestBody) Settings(ctx context.Context) (*Settings, error) { + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-r.connCtx.Done(): + return nil, context.Cause(r.connCtx) + case <-r.rcvdSettings: + return r.getSettings(), nil + } +} + type hijackableBody struct { body conn quic.Connection // only needed to implement Hijacker @@ -84,24 +118,19 @@ type hijackableBody struct { } var ( - _ Hijacker = &hijackableBody{} - _ HTTPStreamer = &hijackableBody{} + _ io.ReadCloser = &hijackableBody{} + _ Hijacker = &hijackableBody{} + _ HTTPStreamer = &hijackableBody{} ) func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { return &hijackableBody{ - body: body{ - str: str, - }, + body: body{str: str}, reqDone: done, conn: conn, } } -func (r *hijackableBody) StreamCreator() StreamCreator { - return r.conn -} - func (r *hijackableBody) Read(b []byte) (int, error) { n, err := r.str.Read(b) if err != nil { @@ -120,10 +149,6 @@ func (r *hijackableBody) requestDone() { r.reqDoneClosed = true } -func (r *body) StreamID() quic.StreamID { - return r.str.StreamID() -} - func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. @@ -131,6 +156,5 @@ func (r *hijackableBody) Close() error { return nil } -func (r *hijackableBody) HTTPStream() Stream { - return r.str -} +func (r *hijackableBody) HTTPStream() Stream { return r.str } +func (r *hijackableBody) StreamCreator() StreamCreator { return r.conn } diff --git a/http3/body_test.go b/http3/body_test.go index 7b96345d..4ec08209 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "errors" "github.com/quic-go/quic-go" @@ -11,6 +12,29 @@ import ( "go.uber.org/mock/gomock" ) +var _ = Describe("Request Body", func() { + It("makes the SETTINGS available", func() { + str := mockquic.NewMockStream(mockCtrl) + rcvdSettings := make(chan struct{}) + close(rcvdSettings) + settings := &Settings{EnableExtendedConnect: true} + body := newRequestBody(str, context.Background(), rcvdSettings, func() *Settings { return settings }) + s, err := body.Settings(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(settings)) + }) + + It("unblocks Settings() when the connection is closed", func() { + str := mockquic.NewMockStream(mockCtrl) + ctx, cancel := context.WithCancelCause(context.Background()) + testErr := errors.New("test error") + cancel(testErr) + body := newRequestBody(str, ctx, make(chan struct{}), func() *Settings { return nil }) + _, err := body.Settings(context.Background()) + Expect(err).To(MatchError(testErr)) + }) +}) + var _ = Describe("Response Body", func() { var reqDone chan struct{} diff --git a/http3/conn.go b/http3/conn.go index a2da426d..cae1f2ce 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -11,7 +11,8 @@ import ( ) type connection struct { - quicConn quic.Connection + quic.Connection + perspective protocol.Perspective logger utils.Logger @@ -30,7 +31,7 @@ func newConnection( logger utils.Logger, ) *connection { return &connection{ - quicConn: quicConn, + Connection: quicConn, perspective: perspective, logger: logger, enableDatagrams: enableDatagrams, @@ -47,7 +48,7 @@ func (c *connection) HandleUnidirectionalStreams() { ) for { - str, err := c.quicConn.AcceptUniStream(context.Background()) + str, err := c.Connection.AcceptUniStream(context.Background()) if err != nil { c.logger.Debugf("accepting unidirectional stream failed: %s", err) return @@ -56,7 +57,7 @@ func (c *connection) HandleUnidirectionalStreams() { go func(str quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { - if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.quicConn, str, err) { + if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, err) { return } c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) @@ -67,13 +68,13 @@ func (c *connection) HandleUnidirectionalStreams() { case streamTypeControlStream: case streamTypeQPACKEncoderStream: if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypeQPACKDecoderStream: if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") } // Our QPACK implementation doesn't use the dynamic table yet. return @@ -81,14 +82,14 @@ func (c *connection) HandleUnidirectionalStreams() { switch c.perspective { case protocol.PerspectiveClient: // we never increased the Push ID, so we don't expect any push streams - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") case protocol.PerspectiveServer: // only the server can push - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") } return default: - if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.quicConn, str, nil) { + if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, nil) { return } str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) @@ -96,17 +97,17 @@ func (c *connection) HandleUnidirectionalStreams() { } // Only a single control stream is allowed. if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") return } f, err := parseNextFrame(str, nil) if err != nil { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } c.settings = &Settings{ @@ -123,8 +124,8 @@ func (c *connection) HandleUnidirectionalStreams() { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.enableDatagrams && !c.quicConn.ConnectionState().SupportsDatagrams { - c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") } }(str) } diff --git a/http3/server.go b/http3/server.go index 6ce186e7..fa3ffbcf 100644 --- a/http3/server.go +++ b/http3/server.go @@ -478,7 +478,7 @@ func (s *Server) handleConn(conn quic.Connection) error { return fmt.Errorf("accepting stream failed: %w", err) } go func() { - rerr := s.handleRequest(conn, str, decoder, func() { + rerr := s.handleRequest(hconn, str, decoder, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) if rerr.err == errHijacked { @@ -510,7 +510,7 @@ func (s *Server) maxHeaderBytes() uint64 { return uint64(s.MaxHeaderBytes) } -func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { +func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { var ufh unknownFrameHandlerFunc if s.StreamHijacker != nil { ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) } @@ -555,7 +555,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } else { httpStr = newStream(str, onFrameError) } - body := newRequestBody(httpStr) + body := newRequestBody(httpStr, conn.Context(), conn.ReceivedSettings(), conn.Settings) req.Body = body if s.logger.Debug() { diff --git a/http3/server_test.go b/http3/server_test.go index c7ea5a2f..789770f1 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -87,7 +87,7 @@ var _ = Describe("Server", func() { var ( qpackDecoder *qpack.Decoder str *mockquic.MockStream - conn *mockquic.MockEarlyConnection + conn *connection exampleGetRequest *http.Request examplePostRequest *http.Request ) @@ -140,11 +140,13 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) - conn = mockquic.NewMockEarlyConnection(mockCtrl) + qconn := mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() - conn.EXPECT().LocalAddr().AnyTimes() - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() + qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes() + qconn.EXPECT().LocalAddr().AnyTimes() + qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() + qconn.EXPECT().Context().Return(context.Background()).AnyTimes() + conn = newConnection(qconn, false, nil, protocol.PerspectiveServer, utils.DefaultLogger) }) It("calls the HTTP handler function", func() { @@ -514,6 +516,7 @@ var _ = Describe("Server", func() { conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() conn.EXPECT().LocalAddr().AnyTimes() conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() + conn.EXPECT().Context().Return(context.Background()).AnyTimes() }) AfterEach(func() { testDone <- struct{}{} }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 8e57d906..89fe0fde 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -579,6 +579,41 @@ var _ = Describe("HTTP tests", func() { Expect(err).To(MatchError(err)) }) + It("receives the client's settings", func() { + settingsChan := make(chan *http3.Settings, 1) + mux.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + // The http.Request.Body is guaranteed to implement the http3.Settingser interface. + settings, err := r.Body.(http3.Settingser).Settings(context.Background()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + settingsChan <- settings + w.WriteHeader(http.StatusOK) + }) + + rt = &http3.RoundTripper{ + TLSClientConfig: getTLSClientConfigWithoutServerName(), + QUICConfig: getQuicConfig(&quic.Config{ + MaxIdleTimeout: 10 * time.Second, + EnableDatagrams: true, + }), + EnableDatagrams: true, + AdditionalSettings: map[uint64]uint64{1337: 42}, + } + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/settings", port), nil) + Expect(err).ToNot(HaveOccurred()) + + _, err = rt.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + var settings *http3.Settings + Eventually(settingsChan).Should(Receive(&settings)) + Expect(settings.EnableDatagram).To(BeTrue()) + Expect(settings.EnableExtendedConnect).To(BeFalse()) + Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) + }) + Context("0-RTT", func() { runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) { var num0RTTPackets atomic.Uint32