diff --git a/http3/client_test.go b/http3/client_test.go index c4ba168cc..3f79e9a7d 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -26,7 +26,7 @@ func encodeResponse(status int) []byte { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) + rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger) rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -738,7 +738,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) + rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) @@ -764,7 +764,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger) + rw := newResponseWriter(newStream(rstr, nil), nil, false, utils.DefaultLogger) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) diff --git a/http3/http_stream.go b/http3/http_stream.go index cd8214fbb..32792a4f6 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -101,6 +101,10 @@ func (s *stream) Write(b []byte) (int, error) { return s.Stream.Write(b) } +func (s *stream) writeUnframed(b []byte) (int, error) { + return s.Stream.Write(b) +} + func (s *stream) StreamID() protocol.StreamID { return s.Stream.StreamID() } diff --git a/http3/response_writer.go b/http3/response_writer.go index ba8b34db0..1699b1b96 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/qpack" @@ -18,46 +17,15 @@ import ( // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) const frameHeaderLen = 16 -// headerWriter wraps the stream, so that the first Write call flushes the header to the stream -type headerWriter struct { - str quic.Stream - header http.Header - status int // status code passed to WriteHeader - - logger utils.Logger -} - -// writeHeader encodes and flush header to the stream -func (hw *headerWriter) writeHeader() error { - var headers bytes.Buffer - enc := qpack.NewEncoder(&headers) - if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}); err != nil { - return err - } - - for k, v := range hw.header { - for index := range v { - if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil { - return err - } - } - } - - buf := make([]byte, 0, frameHeaderLen+headers.Len()) - buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) - hw.logger.Infof("Responding with %d", hw.status) - buf = append(buf, headers.Bytes()...) - - _, err := hw.str.Write(buf) - return err -} - const maxSmallResponseSize = 4096 type responseWriter struct { - *headerWriter - conn Connection - buf []byte + str *stream + + conn Connection + header http.Header + buf []byte + status int // status code passed to WriteHeader // for responses smaller than maxSmallResponseSize, we buffer calls to Write, // and automatically add the Content-Length header @@ -68,6 +36,8 @@ type responseWriter struct { headerComplete bool // set once WriteHeader is called with a status code >= 200 headerWritten bool // set once the response header has been serialized to the stream isHead bool + + logger utils.Logger } var ( @@ -76,17 +46,14 @@ var ( _ Hijacker = &responseWriter{} ) -func newResponseWriter(str quic.Stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter { - hw := &headerWriter{ - str: str, - header: http.Header{}, - logger: logger, - } +func newResponseWriter(str *stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter { return &responseWriter{ - headerWriter: hw, - buf: make([]byte, frameHeaderLen), - conn: conn, - isHead: isHead, + str: str, + conn: conn, + header: http.Header{}, + buf: make([]byte, frameHeaderLen), + isHead: isHead, + logger: logger, } } @@ -107,7 +74,7 @@ func (w *responseWriter) WriteHeader(status int) { // immediately write 1xx headers if status < 200 { - w.writeHeader() + w.writeHeader(status) return } @@ -178,7 +145,7 @@ func (w *responseWriter) Write(p []byte) (int, error) { func (w *responseWriter) doWrite(p []byte) (int, error) { if !w.headerWritten { - if err := w.writeHeader(); err != nil { + if err := w.writeHeader(w.status); err != nil { return 0, maybeReplaceError(err) } w.headerWritten = true @@ -191,11 +158,11 @@ func (w *responseWriter) doWrite(p []byte) (int, error) { df := &dataFrame{Length: l} w.buf = w.buf[:0] w.buf = df.Append(w.buf) - if _, err := w.str.Write(w.buf); err != nil { + if _, err := w.str.writeUnframed(w.buf); err != nil { return 0, maybeReplaceError(err) } if len(w.smallResponseBuf) > 0 { - if _, err := w.str.Write(w.smallResponseBuf); err != nil { + if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil { return 0, maybeReplaceError(err) } w.smallResponseBuf = nil @@ -203,7 +170,7 @@ func (w *responseWriter) doWrite(p []byte) (int, error) { var n int if len(p) > 0 { var err error - n, err = w.str.Write(p) + n, err = w.str.writeUnframed(p) if err != nil { return n, maybeReplaceError(err) } @@ -211,6 +178,29 @@ func (w *responseWriter) doWrite(p []byte) (int, error) { return n, nil } +func (w *responseWriter) writeHeader(status int) error { + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil { + return err + } + + for k, v := range w.header { + for index := range v { + if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil { + return err + } + } + } + + buf := make([]byte, 0, frameHeaderLen+headers.Len()) + buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) + buf = append(buf, headers.Bytes()...) + + _, err := w.str.writeUnframed(buf) + return err +} + func (w *responseWriter) FlushError() error { if !w.headerComplete { w.WriteHeader(http.StatusOK) diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 3254b6e2c..011bf876d 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -28,7 +28,7 @@ var _ = Describe("Response Writer", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() - rw = newResponseWriter(str, nil, false, utils.DefaultLogger) + rw = newResponseWriter(newStream(str, nil), nil, false, utils.DefaultLogger) }) decodeHeader := func(str io.Reader) map[string][]string { diff --git a/http3/server.go b/http3/server.go index ec7a979f3..2f26eb817 100644 --- a/http3/server.go +++ b/http3/server.go @@ -531,7 +531,8 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { contentLength = req.ContentLength } - body := newRequestBody(newStream(str, conn), contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings) + hstr := newStream(str, conn) + body := newRequestBody(hstr, contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings) req.Body = body if s.logger.Debug() { @@ -551,7 +552,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack } } req = req.WithContext(ctx) - r := newResponseWriter(str, conn, req.Method == http.MethodHead, s.logger) + r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.logger) handler := s.Handler if handler == nil { handler = http.DefaultServeMux