forked from quic-go/quic-go
http3: simplify composition of the HTTP stream and request stream (#4433)
This commit is contained in:
@@ -222,14 +222,12 @@ func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (
|
||||
return nil, err
|
||||
}
|
||||
return newRequestStream(
|
||||
newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }),
|
||||
c.hconn,
|
||||
newStream(str, c.hconn),
|
||||
c.requestWriter,
|
||||
nil,
|
||||
c.decoder,
|
||||
c.DisableCompression,
|
||||
c.maxHeaderBytes(),
|
||||
func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") },
|
||||
), nil
|
||||
}
|
||||
|
||||
@@ -274,14 +272,12 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read
|
||||
|
||||
func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) {
|
||||
hstr := newRequestStream(
|
||||
newStream(str, func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") }),
|
||||
c.hconn,
|
||||
newStream(str, c.hconn),
|
||||
c.requestWriter,
|
||||
reqDone,
|
||||
c.decoder,
|
||||
c.DisableCompression,
|
||||
c.maxHeaderBytes(),
|
||||
func(e ErrCode) { c.Connection.CloseWithError(quic.ApplicationErrorCode(e), "") },
|
||||
)
|
||||
if err := hstr.SendRequestHeader(req); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -35,7 +35,7 @@ type RequestStream interface {
|
||||
|
||||
type stream struct {
|
||||
quic.Stream
|
||||
closeConnection func(ErrCode)
|
||||
conn quic.Connection
|
||||
|
||||
buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers
|
||||
|
||||
@@ -44,10 +44,10 @@ type stream struct {
|
||||
|
||||
var _ Stream = &stream{}
|
||||
|
||||
func newStream(str quic.Stream, closeConnection func(ErrCode)) *stream {
|
||||
func newStream(str quic.Stream, conn quic.Connection) *stream {
|
||||
return &stream{
|
||||
Stream: str,
|
||||
closeConnection: closeConnection,
|
||||
conn: conn,
|
||||
buf: make([]byte, 0, 16),
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (s *stream) Read(b []byte) (int, error) {
|
||||
s.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
default:
|
||||
s.closeConnection(ErrCodeFrameUnexpected)
|
||||
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
@@ -105,13 +105,10 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||
type requestStream struct {
|
||||
*stream
|
||||
|
||||
conn quic.Connection
|
||||
|
||||
responseBody io.ReadCloser // set by ReadResponse
|
||||
|
||||
decoder *qpack.Decoder
|
||||
requestWriter *requestWriter
|
||||
closeConnection func(ErrCode)
|
||||
maxHeaderBytes uint64
|
||||
reqDone chan<- struct{}
|
||||
disableCompression bool
|
||||
@@ -125,22 +122,18 @@ var _ RequestStream = &requestStream{}
|
||||
|
||||
func newRequestStream(
|
||||
str *stream,
|
||||
conn quic.Connection,
|
||||
requestWriter *requestWriter,
|
||||
reqDone chan<- struct{},
|
||||
decoder *qpack.Decoder,
|
||||
disableCompression bool,
|
||||
maxHeaderBytes uint64,
|
||||
closeConnection func(ErrCode),
|
||||
) *requestStream {
|
||||
return &requestStream{
|
||||
stream: str,
|
||||
conn: conn,
|
||||
requestWriter: requestWriter,
|
||||
reqDone: reqDone,
|
||||
decoder: decoder,
|
||||
disableCompression: disableCompression,
|
||||
closeConnection: closeConnection,
|
||||
maxHeaderBytes: maxHeaderBytes,
|
||||
}
|
||||
}
|
||||
@@ -174,7 +167,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
s.closeConnection(ErrCodeFrameUnexpected)
|
||||
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
|
||||
return nil, errors.New("http3: expected first frame to be a HEADERS frame")
|
||||
}
|
||||
if hf.Length > s.maxHeaderBytes {
|
||||
@@ -191,7 +184,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
|
||||
hfs, err := s.decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
s.closeConnection(ErrCodeGeneralProtocolError)
|
||||
s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "")
|
||||
return nil, fmt.Errorf("http3: failed to decode response headers: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"go.uber.org/mock/gomock"
|
||||
@@ -30,15 +32,18 @@ var _ = Describe("Stream", func() {
|
||||
errorCbCalled bool
|
||||
)
|
||||
|
||||
errorCb := func(ErrCode) { errorCbCalled = true }
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
errorCbCalled = false
|
||||
qstr = mockquic.NewMockStream(mockCtrl)
|
||||
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str = newStream(qstr, errorCb)
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
|
||||
errorCbCalled = true
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
str = newStream(qstr, conn)
|
||||
})
|
||||
|
||||
It("reads DATA frames in a single run", func() {
|
||||
@@ -167,7 +172,7 @@ var _ = Describe("length-limited streams", func() {
|
||||
qstr = mockquic.NewMockStream(mockCtrl)
|
||||
qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str = newStream(qstr, func(ErrCode) { Fail("didn't expect error callback to be called") })
|
||||
str = newStream(qstr, mockquic.NewMockEarlyConnection(mockCtrl))
|
||||
})
|
||||
|
||||
It("reads all frames", func() {
|
||||
@@ -213,15 +218,14 @@ var _ = Describe("Request Stream", func() {
|
||||
BeforeEach(func() {
|
||||
qstr = mockquic.NewMockStream(mockCtrl)
|
||||
requestWriter := newRequestWriter(utils.DefaultLogger)
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
str = newRequestStream(
|
||||
newStream(qstr, func(code ErrCode) { Fail("errored") }),
|
||||
mockquic.NewMockEarlyConnection(mockCtrl),
|
||||
newStream(qstr, conn),
|
||||
requestWriter,
|
||||
make(chan struct{}),
|
||||
qpack.NewDecoder(func(qpack.HeaderField) {}),
|
||||
true,
|
||||
math.MaxUint64,
|
||||
func(code ErrCode) { Fail("errored") },
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
@@ -528,11 +528,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
|
||||
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength)
|
||||
httpStr = newLengthLimitedStream(newStream(str, conn), req.ContentLength)
|
||||
} else {
|
||||
httpStr = newStream(str, closeConnection)
|
||||
httpStr = newStream(str, conn)
|
||||
}
|
||||
body := newRequestBody(httpStr, conn.Context(), conn.ReceivedSettings(), conn.Settings)
|
||||
req.Body = body
|
||||
|
||||
Reference in New Issue
Block a user