http3: simplify composition of the HTTP stream and request stream (#4433)

This commit is contained in:
Marten Seemann
2024-04-13 17:18:51 -07:00
committed by GitHub
parent 90627f6f7c
commit 25cd4b5d24
4 changed files with 24 additions and 32 deletions

View File

@@ -222,14 +222,12 @@ func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (
return nil, err
}
return newRequestStream(
newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }),
c.hconn,
newStream(str, c.hconn),
c.requestWriter,
nil,
c.decoder,
c.DisableCompression,
c.maxHeaderBytes(),
func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") },
), nil
}
@@ -274,14 +272,12 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read
func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) {
hstr := newRequestStream(
newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }),
c.hconn,
newStream(str, c.hconn),
c.requestWriter,
reqDone,
c.decoder,
c.DisableCompression,
c.maxHeaderBytes(),
func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") },
)
if err := hstr.SendRequestHeader(req); err != nil {
return nil, err

View File

@@ -35,7 +35,7 @@ type RequestStream interface {
type stream struct {
quic.Stream
closeConnection func(ErrCode)
conn quic.Connection
buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers
@@ -44,10 +44,10 @@ type stream struct {
var _ Stream = &stream{}
func newStream(str quic.Stream, closeConnection func(ErrCode)) *stream {
func newStream(str quic.Stream, conn quic.Connection) *stream {
return &stream{
Stream: str,
closeConnection: closeConnection,
conn: conn,
buf: make([]byte, 0, 16),
}
}
@@ -68,7 +68,7 @@ func (s *stream) Read(b []byte) (int, error) {
s.bytesRemainingInFrame = f.Length
break parseLoop
default:
s.closeConnection(ErrCodeFrameUnexpected)
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
// parseNextFrame skips over unknown frame types
// Therefore, this condition is only entered when we parsed another known frame type.
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
@@ -105,13 +105,10 @@ func (s *stream) Write(b []byte) (int, error) {
type requestStream struct {
*stream
conn quic.Connection
responseBody io.ReadCloser // set by ReadResponse
decoder *qpack.Decoder
requestWriter *requestWriter
closeConnection func(ErrCode)
maxHeaderBytes uint64
reqDone chan<- struct{}
disableCompression bool
@@ -125,22 +122,18 @@ var _ RequestStream = &requestStream{}
func newRequestStream(
str *stream,
conn quic.Connection,
requestWriter *requestWriter,
reqDone chan<- struct{},
decoder *qpack.Decoder,
disableCompression bool,
maxHeaderBytes uint64,
closeConnection func(ErrCode),
) *requestStream {
return &requestStream{
stream: str,
conn: conn,
requestWriter: requestWriter,
reqDone: reqDone,
decoder: decoder,
disableCompression: disableCompression,
closeConnection: closeConnection,
maxHeaderBytes: maxHeaderBytes,
}
}
@@ -174,7 +167,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
}
hf, ok := frame.(*headersFrame)
if !ok {
s.closeConnection(ErrCodeFrameUnexpected)
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
return nil, errors.New("http3: expected first frame to be a HEADERS frame")
}
if hf.Length > s.maxHeaderBytes {
@@ -191,7 +184,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
hfs, err := s.decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
s.closeConnection(ErrCodeGeneralProtocolError)
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "")
return nil, fmt.Errorf("http3: failed to decode response headers: %w", err)
}

View File

@@ -6,11 +6,13 @@ import (
"math"
"net/http"
"github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/qpack"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
@@ -30,15 +32,18 @@ var _ = Describe("Stream", func() {
errorCbCalled bool
)
errorCb := func(ErrCode) { errorCbCalled = true }
BeforeEach(func() {
buf = &bytes.Buffer{}
errorCbCalled = false
qstr = mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str = newStream(qstr, errorCb)
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
errorCbCalled = true
return nil
}).AnyTimes()
str = newStream(qstr, conn)
})
It("reads DATA frames in a single run", func() {
@@ -167,7 +172,7 @@ var _ = Describe("length-limited streams", func() {
qstr = mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str = newStream(qstr, func(ErrCode) { Fail("didn't expect error callback to be called") })
str = newStream(qstr, mockquic.NewMockEarlyConnection(mockCtrl))
})
It("reads all frames", func() {
@@ -213,15 +218,14 @@ var _ = Describe("Request Stream", func() {
BeforeEach(func() {
qstr = mockquic.NewMockStream(mockCtrl)
requestWriter := newRequestWriter(utils.DefaultLogger)
conn := mockquic.NewMockEarlyConnection(mockCtrl)
str = newRequestStream(
newStream(qstr, func(code ErrCode) { Fail("errored") }),
mockquic.NewMockEarlyConnection(mockCtrl),
newStream(qstr, conn),
requestWriter,
make(chan struct{}),
qpack.NewDecoder(func(qpack.HeaderField) {}),
true,
math.MaxUint64,
func(code ErrCode) { Fail("errored") },
)
})

View File

@@ -528,11 +528,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength)
httpStr = newLengthLimitedStream(newStream(str, conn), req.ContentLength)
} else {
httpStr = newStream(str, closeConnection)
httpStr = newStream(str, conn)
}
body := newRequestBody(httpStr, conn.Context(), conn.ReceivedSettings(), conn.Settings)
req.Body = body