From 89d9984156c3d9fff6349ad87407ed469736b48c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 22 Apr 2025 19:24:46 +0800 Subject: [PATCH] http3: simplify request writer by writing to an io.Writer (#5070) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It’s always preferable to use the most minimal interface. --- http3/request_writer.go | 5 ++--- http3/request_writer_test.go | 35 ++++++++++------------------------- 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/http3/request_writer.go b/http3/request_writer.go index 2dbacb150..6d473aa63 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -17,7 +17,6 @@ import ( "golang.org/x/net/idna" "github.com/quic-go/qpack" - "github.com/quic-go/quic-go" ) const bodyCopyBufferSize = 8 * 1024 @@ -37,13 +36,13 @@ func newRequestWriter() *requestWriter { } } -func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error { +func (w *requestWriter) WriteRequestHeader(wr io.Writer, req *http.Request, gzip bool) error { // TODO: figure out how to add support for trailers buf := &bytes.Buffer{} if err := w.writeHeaders(buf, req, gzip); err != nil { return err } - if _, err := str.Write(buf.Bytes()); err != nil { + if _, err := wr.Write(buf.Bytes()); err != nil { return err } trace := httptrace.ContextClientTrace(req.Context()) diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 5e5d7bd49..85ee512bd 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -7,10 +7,8 @@ import ( "testing" "github.com/quic-go/qpack" - mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" ) func decodeRequest(t *testing.T, str io.Reader) map[string]string { @@ -49,12 +47,9 @@ func testRequestWriterGzip(t *testing.T, gzip bool) { req.AddCookie(&http.Cookie{Name: "baz", Value: "lorem ipsum"}) rw := newRequestWriter() - strBuf := &bytes.Buffer{} - mockCtrl := gomock.NewController(t) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - require.NoError(t, rw.WriteRequestHeader(str, req, gzip)) - headerFields := decodeRequest(t, strBuf) + buf := &bytes.Buffer{} + require.NoError(t, rw.WriteRequestHeader(buf, req, gzip)) + headerFields := decodeRequest(t, buf) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, http.MethodGet, headerFields[":method"]) require.Equal(t, "/index.html?foo=bar", headerFields[":path"]) @@ -73,12 +68,8 @@ func TestRequestWriterInvalidHostHeader(t *testing.T) { require.NoError(t, err) req.Host = "foo@bar" // @ is invalid rw := newRequestWriter() - strBuf := &bytes.Buffer{} - mockCtrl := gomock.NewController(t) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() require.EqualError(t, - rw.WriteRequestHeader(str, req, false), + rw.WriteRequestHeader(&bytes.Buffer{}, req, false), "http3: invalid Host header", ) } @@ -87,12 +78,9 @@ func TestRequestWriterConnect(t *testing.T) { req, err := http.NewRequest(http.MethodConnect, "https://quic-go.net/", nil) require.NoError(t, err) rw := newRequestWriter() - strBuf := &bytes.Buffer{} - mockCtrl := gomock.NewController(t) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - require.NoError(t, rw.WriteRequestHeader(str, req, false)) - headerFields := decodeRequest(t, strBuf) + buf := &bytes.Buffer{} + require.NoError(t, rw.WriteRequestHeader(buf, req, false)) + headerFields := decodeRequest(t, buf) require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.NotContains(t, headerFields, ":path") @@ -105,12 +93,9 @@ func TestRequestWriterExtendedConnect(t *testing.T) { require.NoError(t, err) req.Proto = "webtransport" rw := newRequestWriter() - strBuf := &bytes.Buffer{} - mockCtrl := gomock.NewController(t) - str := mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - require.NoError(t, rw.WriteRequestHeader(str, req, false)) - headerFields := decodeRequest(t, strBuf) + buf := &bytes.Buffer{} + require.NoError(t, rw.WriteRequestHeader(buf, req, false)) + headerFields := decodeRequest(t, buf) require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, "/", headerFields[":path"])