diff --git a/http3/client.go b/http3/client.go index a32844e38..501e3db57 100644 --- a/http3/client.go +++ b/http3/client.go @@ -2,6 +2,7 @@ package http3 import ( "context" + "errors" "fmt" "io" "log/slog" @@ -180,6 +181,21 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp } } + // It is only possible to send an Extended CONNECT request once the SETTINGS were received. + // See section 3 of RFC 8441. + if isExtendedConnectRequest(req) { + connCtx := c.Connection.Context() + // wait for the server's SETTINGS frame to arrive + select { + case <-c.hconn.ReceivedSettings(): + case <-connCtx.Done(): + return nil, context.Cause(connCtx) + } + if !c.hconn.Settings().EnableExtendedConnect { + return nil, errors.New("http3: server didn't enable Extended CONNECT") + } + } + str, err := c.Connection.OpenStreamSync(req.Context()) if err != nil { return nil, err diff --git a/http3/client_test.go b/http3/client_test.go index 2c7eb4bfc..7adc62e85 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -375,6 +375,73 @@ var _ = Describe("Client", func() { conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) close(done) }) + + It("checks the server's SETTINGS before sending an Extended CONNECT request", func() { + done := make(chan struct{}) + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { + <-done + return nil, errors.New("test done") + }).MaxTimes(1) + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{ExtendedConnect: true}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-done + return nil, errors.New("test done") + }) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("test error")) + + rt := &SingleDestinationRoundTripper{Connection: conn} + _, err := rt.RoundTrip(&http.Request{ + Method: http.MethodConnect, + Proto: "connect", + Host: "localhost", + }) + Expect(err).To(MatchError("test error")) + + // test shutdown + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + close(done) + }) + + It("rejects Extended CONNECT requests if the server doesn't enable it", func() { + done := make(chan struct{}) + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { + <-done + return nil, errors.New("test done") + }).MaxTimes(1) + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{Datagram: true}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-done + return nil, errors.New("test done") + }) + conn.EXPECT().HandshakeComplete().Return(handshakeChan) + conn.EXPECT().Context().Return(context.Background()) + + rt := &SingleDestinationRoundTripper{Connection: conn} + _, err := rt.RoundTrip(&http.Request{ + Method: http.MethodConnect, + Proto: "connect", + Host: "localhost", + }) + Expect(err).To(MatchError("http3: server didn't enable Extended CONNECT")) + + // test shutdown + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + close(done) + }) }) Context("Doing requests", func() { diff --git a/http3/request_writer.go b/http3/request_writer.go index 995b8d7a7..e83e9cc81 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -65,6 +65,10 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) return err } +func isExtendedConnectRequest(req *http.Request) bool { + return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" +} + // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -83,7 +87,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } // http.NewRequest sets this field to HTTP/1.1 - isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + isExtendedConnect := isExtendedConnectRequest(req) var path string if req.Method != http.MethodConnect || isExtendedConnect {