diff --git a/http3/client_test.go b/http3/client_test.go index 9dfe27f5d..7386e5f83 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -61,7 +61,7 @@ func encodeResponse(t *testing.T, status int) []byte { buf := &bytes.Buffer{} rstr := NewMockDatagramStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -326,7 +326,7 @@ func testClient1xxHandling(t *testing.T, numEarlyHints int, terminalStatus int, var rspBuf bytes.Buffer rstr := NewMockDatagramStream(gomock.NewController(t)) rstr.EXPECT().Write(gomock.Any()).Do(rspBuf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) rw.header.Add("Link", "foo") rw.header.Add("Link", "bar") for range numEarlyHints { @@ -405,7 +405,7 @@ func testClientGzip(t *testing.T, var rspBuf bytes.Buffer rstr := NewMockDatagramStream(gomock.NewController(t)) rstr.EXPECT().Write(gomock.Any()).Do(rspBuf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) rw.WriteHeader(http.StatusOK) if responseAddContentEncoding { rw.header.Add("Content-Encoding", "gzip") diff --git a/http3/conn.go b/http3/conn.go index d1ae3400b..63ac9cbe2 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -165,7 +165,7 @@ func (c *Conn) openRequestStream( rsp := &http.Response{} trace := httptrace.ContextClientTrace(ctx) return newRequestStream( - newStream(hstr, c, func(r io.Reader, l uint64) error { + newStream(hstr, c, trace, func(r io.Reader, l uint64) error { hdr, err := c.decodeTrailers(r, l, maxHeaderBytes) if err != nil { return err @@ -179,7 +179,6 @@ func (c *Conn) openRequestStream( disableCompression, maxHeaderBytes, rsp, - trace, ), nil } diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 07ebaad6d..0247b6cbc 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -61,7 +61,7 @@ func newTestResponseWriter(t *testing.T) *testResponseWriter { str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() - rw := newResponseWriter(newStream(str, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) + rw := newResponseWriter(newStream(str, nil, nil, func(r io.Reader, u uint64) error { return nil }), nil, false, nil) return &testResponseWriter{responseWriter: rw, buf: buf} } diff --git a/http3/server.go b/http3/server.go index e57944357..8d53a7246 100644 --- a/http3/server.go +++ b/http3/server.go @@ -610,7 +610,7 @@ func (s *Server) handleRequest(conn *Conn, str datagramStream, decoder *qpack.De if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { contentLength = req.ContentLength } - hstr := newStream(str, conn, nil) + hstr := newStream(str, conn, nil, nil) body := newRequestBody(hstr, contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings) req.Body = body diff --git a/http3/stream.go b/http3/stream.go index 11b976796..b2a644753 100644 --- a/http3/stream.go +++ b/http3/stream.go @@ -35,7 +35,8 @@ type datagramStream interface { // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. type Stream struct { datagramStream - conn *Conn + conn *Conn + frameParser *frameParser buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers @@ -45,12 +46,16 @@ type Stream struct { parsedTrailer bool } -func newStream(str datagramStream, conn *Conn, parseTrailer func(io.Reader, uint64) error) *Stream { +func newStream(str datagramStream, conn *Conn, trace *httptrace.ClientTrace, parseTrailer func(io.Reader, uint64) error) *Stream { return &Stream{ datagramStream: str, conn: conn, buf: make([]byte, 16), parseTrailer: parseTrailer, + frameParser: &frameParser{ + closeConn: conn.CloseWithError, + r: &tracingReader{Reader: str, trace: trace}, + }, } } @@ -151,12 +156,10 @@ type RequestStream struct { reqDone chan<- struct{} disableCompression bool response *http.Response - trace *httptrace.ClientTrace sentRequest bool requestedGzip bool isConnect bool - firstByte bool } func newRequestStream( @@ -167,7 +170,6 @@ func newRequestStream( disableCompression bool, maxHeaderBytes uint64, rsp *http.Response, - trace *httptrace.ClientTrace, ) *RequestStream { return &RequestStream{ str: str, @@ -177,7 +179,6 @@ func newRequestStream( disableCompression: disableCompression, maxHeaderBytes: maxHeaderBytes, response: rsp, - trace: trace, } } @@ -284,16 +285,7 @@ func (s *RequestStream) SendRequestHeader(req *http.Request) error { // It doesn't set Response.Request and Response.TLS. // It is invalid to call it after Read has been called. func (s *RequestStream) ReadResponse() (*http.Response, error) { - qstr := s.str.datagramStream - fp := &frameParser{ - closeConn: s.str.conn.CloseWithError, - r: &tracingReader{ - Reader: qstr, - first: &s.firstByte, - trace: s.trace, - }, - } - frame, err := fp.ParseNext() + frame, err := s.str.frameParser.ParseNext() if err != nil { s.str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) s.str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) @@ -310,7 +302,7 @@ func (s *RequestStream) ReadResponse() (*http.Response, error) { return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes) } headerBlock := make([]byte, hf.Length) - if _, err := io.ReadFull(qstr, headerBlock); err != nil { + if _, err := io.ReadFull(s.str.datagramStream, headerBlock); err != nil { s.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) s.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) return nil, fmt.Errorf("http3: failed to read response headers: %w", err) @@ -354,15 +346,15 @@ func (s *RequestStream) ReadResponse() (*http.Response, error) { type tracingReader struct { io.Reader - first *bool - trace *httptrace.ClientTrace + readFirst bool + trace *httptrace.ClientTrace } func (r *tracingReader) Read(b []byte) (int, error) { n, err := r.Reader.Read(b) - if n > 0 && r.first != nil && !*r.first { + if n > 0 && !r.readFirst { traceGotFirstResponseByte(r.trace) - *r.first = true + r.readFirst = true } return n, err } diff --git a/http3/stream_test.go b/http3/stream_test.go index 64ab050d8..e0b423e3f 100644 --- a/http3/stream_test.go +++ b/http3/stream_test.go @@ -43,6 +43,7 @@ func TestStreamReadDataFrames(t *testing.T) { nil, 0, ), + nil, func(r io.Reader, u uint64) error { return nil }, ) @@ -91,6 +92,7 @@ func TestStreamInvalidFrame(t *testing.T) { str := newStream( qstr, newConnection(context.Background(), clientConn, false, protocol.PerspectiveClient, nil, 0), + nil, func(r io.Reader, u uint64) error { return nil }, ) @@ -112,7 +114,7 @@ func TestStreamWrite(t *testing.T) { mockCtrl := gomock.NewController(t) qstr := NewMockDatagramStream(mockCtrl) qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - str := newStream(qstr, nil, func(r io.Reader, u uint64) error { return nil }) + str := newStream(qstr, nil, nil, func(r io.Reader, u uint64) error { return nil }) str.Write([]byte("foo")) str.Write([]byte("foobar")) @@ -144,6 +146,7 @@ func TestRequestStream(t *testing.T) { newStream( qstr, newConnection(context.Background(), clientConn, false, protocol.PerspectiveClient, nil, 0), + &httptrace.ClientTrace{}, func(r io.Reader, u uint64) error { return nil }, ), requestWriter, @@ -152,7 +155,6 @@ func TestRequestStream(t *testing.T) { true, math.MaxUint64, &http.Response{}, - &httptrace.ClientTrace{}, ) _, err := str.Read(make([]byte, 100))