From 25cd4b5d24e42f9bd5630526ee57f35b7f4a9f98 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 13 Apr 2024 17:18:51 -0700 Subject: [PATCH] http3: simplify composition of the HTTP stream and request stream (#4433) --- http3/client.go | 8 ++------ http3/http_stream.go | 23 ++++++++--------------- http3/http_stream_test.go | 20 ++++++++++++-------- http3/server.go | 5 ++--- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/http3/client.go b/http3/client.go index e61cf6e6..6504d987 100644 --- a/http3/client.go +++ b/http3/client.go @@ -222,14 +222,12 @@ func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) ( return nil, err } return newRequestStream( - newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }), - c.hconn, + newStream(str, c.hconn), c.requestWriter, nil, c.decoder, c.DisableCompression, c.maxHeaderBytes(), - func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }, ), nil } @@ -274,14 +272,12 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) { hstr := newRequestStream( - newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }), - c.hconn, + newStream(str, c.hconn), c.requestWriter, reqDone, c.decoder, c.DisableCompression, c.maxHeaderBytes(), - func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }, ) if err := hstr.SendRequestHeader(req); err != nil { return nil, err diff --git a/http3/http_stream.go b/http3/http_stream.go index f0af1f4b..10955c0d 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -35,7 +35,7 @@ type RequestStream interface { type stream struct { quic.Stream - closeConnection func(ErrCode) + conn quic.Connection buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers @@ -44,11 +44,11 @@ type stream struct { var _ Stream = &stream{} -func newStream(str quic.Stream, closeConnection func(ErrCode)) *stream { +func newStream(str quic.Stream, conn quic.Connection) *stream { return &stream{ - Stream: str, - closeConnection: closeConnection, - buf: make([]byte, 0, 16), + Stream: str, + conn: conn, + buf: make([]byte, 0, 16), } } @@ -68,7 +68,7 @@ func (s *stream) Read(b []byte) (int, error) { s.bytesRemainingInFrame = f.Length break parseLoop default: - s.closeConnection(ErrCodeFrameUnexpected) + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") // parseNextFrame skips over unknown frame types // Therefore, this condition is only entered when we parsed another known frame type. return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) @@ -105,13 +105,10 @@ func (s *stream) Write(b []byte) (int, error) { type requestStream struct { *stream - conn quic.Connection - responseBody io.ReadCloser // set by ReadResponse decoder *qpack.Decoder requestWriter *requestWriter - closeConnection func(ErrCode) maxHeaderBytes uint64 reqDone chan<- struct{} disableCompression bool @@ -125,22 +122,18 @@ var _ RequestStream = &requestStream{} func newRequestStream( str *stream, - conn quic.Connection, requestWriter *requestWriter, reqDone chan<- struct{}, decoder *qpack.Decoder, disableCompression bool, maxHeaderBytes uint64, - closeConnection func(ErrCode), ) *requestStream { return &requestStream{ stream: str, - conn: conn, requestWriter: requestWriter, reqDone: reqDone, decoder: decoder, disableCompression: disableCompression, - closeConnection: closeConnection, maxHeaderBytes: maxHeaderBytes, } } @@ -174,7 +167,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) { } hf, ok := frame.(*headersFrame) if !ok { - s.closeConnection(ErrCodeFrameUnexpected) + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame") return nil, errors.New("http3: expected first frame to be a HEADERS frame") } if hf.Length > s.maxHeaderBytes { @@ -191,7 +184,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) { hfs, err := s.decoder.DecodeFull(headerBlock) if err != nil { // TODO: use the right error code - s.closeConnection(ErrCodeGeneralProtocolError) + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "") return nil, fmt.Errorf("http3: failed to decode response headers: %w", err) } diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index e99115f5..d5f8446f 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -6,11 +6,13 @@ import ( "math" "net/http" - "github.com/quic-go/qpack" "github.com/quic-go/quic-go" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" + "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/qpack" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" @@ -30,15 +32,18 @@ var _ = Describe("Stream", func() { errorCbCalled bool ) - errorCb := func(ErrCode) { errorCbCalled = true } - BeforeEach(func() { buf = &bytes.Buffer{} errorCbCalled = false qstr = mockquic.NewMockStream(mockCtrl) qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - str = newStream(qstr, errorCb) + conn := mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error { + errorCbCalled = true + return nil + }).AnyTimes() + str = newStream(qstr, conn) }) It("reads DATA frames in a single run", func() { @@ -167,7 +172,7 @@ var _ = Describe("length-limited streams", func() { qstr = mockquic.NewMockStream(mockCtrl) qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - str = newStream(qstr, func(ErrCode) { Fail("didn't expect error callback to be called") }) + str = newStream(qstr, mockquic.NewMockEarlyConnection(mockCtrl)) }) It("reads all frames", func() { @@ -213,15 +218,14 @@ var _ = Describe("Request Stream", func() { BeforeEach(func() { qstr = mockquic.NewMockStream(mockCtrl) requestWriter := newRequestWriter(utils.DefaultLogger) + conn := mockquic.NewMockEarlyConnection(mockCtrl) str = newRequestStream( - newStream(qstr, func(code ErrCode) { Fail("errored") }), - mockquic.NewMockEarlyConnection(mockCtrl), + newStream(qstr, conn), requestWriter, make(chan struct{}), qpack.NewDecoder(func(qpack.HeaderField) {}), true, math.MaxUint64, - func(code ErrCode) { Fail("errored") }, ) }) diff --git a/http3/server.go b/http3/server.go index 86666297..ef545413 100644 --- a/http3/server.go +++ b/http3/server.go @@ -528,11 +528,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack // Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set). // See section 4.1.2 of RFC 9114. var httpStr Stream - closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") } if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { - httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength) + httpStr = newLengthLimitedStream(newStream(str, conn), req.ContentLength) } else { - httpStr = newStream(str, closeConnection) + httpStr = newStream(str, conn) } body := newRequestBody(httpStr, conn.Context(), conn.ReceivedSettings(), conn.Settings) req.Body = body