http3: simplify request writer by writing to an io.Writer (#5070)

It’s always preferable to use the most minimal interface.
This commit is contained in:
Marten Seemann
2025-04-22 19:24:46 +08:00
committed by GitHub
parent f1c7a5df73
commit 89d9984156
2 changed files with 12 additions and 28 deletions

View File

@@ -17,7 +17,6 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
"github.com/quic-go/qpack" "github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
) )
const bodyCopyBufferSize = 8 * 1024 const bodyCopyBufferSize = 8 * 1024
@@ -37,13 +36,13 @@ func newRequestWriter() *requestWriter {
} }
} }
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error { func (w *requestWriter) WriteRequestHeader(wr io.Writer, req *http.Request, gzip bool) error {
// TODO: figure out how to add support for trailers // TODO: figure out how to add support for trailers
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := w.writeHeaders(buf, req, gzip); err != nil { if err := w.writeHeaders(buf, req, gzip); err != nil {
return err return err
} }
if _, err := str.Write(buf.Bytes()); err != nil { if _, err := wr.Write(buf.Bytes()); err != nil {
return err return err
} }
trace := httptrace.ContextClientTrace(req.Context()) trace := httptrace.ContextClientTrace(req.Context())

View File

@@ -7,10 +7,8 @@ import (
"testing" "testing"
"github.com/quic-go/qpack" "github.com/quic-go/qpack"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
) )
func decodeRequest(t *testing.T, str io.Reader) map[string]string { func decodeRequest(t *testing.T, str io.Reader) map[string]string {
@@ -49,12 +47,9 @@ func testRequestWriterGzip(t *testing.T, gzip bool) {
req.AddCookie(&http.Cookie{Name: "baz", Value: "lorem ipsum"}) req.AddCookie(&http.Cookie{Name: "baz", Value: "lorem ipsum"})
rw := newRequestWriter() rw := newRequestWriter()
strBuf := &bytes.Buffer{} buf := &bytes.Buffer{}
mockCtrl := gomock.NewController(t) require.NoError(t, rw.WriteRequestHeader(buf, req, gzip))
str := mockquic.NewMockStream(mockCtrl) headerFields := decodeRequest(t, buf)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
require.NoError(t, rw.WriteRequestHeader(str, req, gzip))
headerFields := decodeRequest(t, strBuf)
require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, "quic-go.net", headerFields[":authority"])
require.Equal(t, http.MethodGet, headerFields[":method"]) require.Equal(t, http.MethodGet, headerFields[":method"])
require.Equal(t, "/index.html?foo=bar", headerFields[":path"]) require.Equal(t, "/index.html?foo=bar", headerFields[":path"])
@@ -73,12 +68,8 @@ func TestRequestWriterInvalidHostHeader(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
req.Host = "foo@bar" // @ is invalid req.Host = "foo@bar" // @ is invalid
rw := newRequestWriter() rw := newRequestWriter()
strBuf := &bytes.Buffer{}
mockCtrl := gomock.NewController(t)
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
require.EqualError(t, require.EqualError(t,
rw.WriteRequestHeader(str, req, false), rw.WriteRequestHeader(&bytes.Buffer{}, req, false),
"http3: invalid Host header", "http3: invalid Host header",
) )
} }
@@ -87,12 +78,9 @@ func TestRequestWriterConnect(t *testing.T) {
req, err := http.NewRequest(http.MethodConnect, "https://quic-go.net/", nil) req, err := http.NewRequest(http.MethodConnect, "https://quic-go.net/", nil)
require.NoError(t, err) require.NoError(t, err)
rw := newRequestWriter() rw := newRequestWriter()
strBuf := &bytes.Buffer{} buf := &bytes.Buffer{}
mockCtrl := gomock.NewController(t) require.NoError(t, rw.WriteRequestHeader(buf, req, false))
str := mockquic.NewMockStream(mockCtrl) headerFields := decodeRequest(t, buf)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
require.NoError(t, rw.WriteRequestHeader(str, req, false))
headerFields := decodeRequest(t, strBuf)
require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, http.MethodConnect, headerFields[":method"])
require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, "quic-go.net", headerFields[":authority"])
require.NotContains(t, headerFields, ":path") require.NotContains(t, headerFields, ":path")
@@ -105,12 +93,9 @@ func TestRequestWriterExtendedConnect(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
req.Proto = "webtransport" req.Proto = "webtransport"
rw := newRequestWriter() rw := newRequestWriter()
strBuf := &bytes.Buffer{} buf := &bytes.Buffer{}
mockCtrl := gomock.NewController(t) require.NoError(t, rw.WriteRequestHeader(buf, req, false))
str := mockquic.NewMockStream(mockCtrl) headerFields := decodeRequest(t, buf)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
require.NoError(t, rw.WriteRequestHeader(str, req, false))
headerFields := decodeRequest(t, strBuf)
require.Equal(t, "quic-go.net", headerFields[":authority"]) require.Equal(t, "quic-go.net", headerFields[":authority"])
require.Equal(t, http.MethodConnect, headerFields[":method"]) require.Equal(t, http.MethodConnect, headerFields[":method"])
require.Equal(t, "/", headerFields[":path"]) require.Equal(t, "/", headerFields[":path"])