From 683230372e051500a1c3af1ed66a8d8b06a1589e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 27 Mar 2020 08:42:35 +0700 Subject: [PATCH] use a buffered writer for the http3 response writer --- http3/client_test.go | 4 ++++ http3/response_writer.go | 15 +++++++++------ http3/response_writer_test.go | 1 + http3/server.go | 1 + http3/server_test.go | 8 ++++---- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/http3/client_test.go b/http3/client_test.go index 4f1e2296f..298d60afc 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -216,6 +216,7 @@ var _ = Describe("Client", func() { rspBuf := &bytes.Buffer{} rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw.WriteHeader(418) + rw.Flush() gomock.InOrder( sess.EXPECT().HandshakeComplete().Return(handshakeCtx), @@ -383,6 +384,7 @@ var _ = Describe("Client", func() { rspBuf := &bytes.Buffer{} rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw.WriteHeader(418) + rw.Flush() ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) @@ -455,6 +457,7 @@ var _ = Describe("Client", func() { gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) gz.Close() + rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return buf.Read(p) @@ -476,6 +479,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rw := newResponseWriter(buf, utils.DefaultLogger) rw.Write([]byte("not gzipped")) + rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return buf.Read(p) diff --git a/http3/response_writer.go b/http3/response_writer.go index a890f5a29..8f4d69604 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -1,6 +1,7 @@ package http3 import ( + "bufio" "bytes" "io" "net/http" @@ -12,7 +13,7 @@ import ( ) type responseWriter struct { - stream io.Writer + stream *bufio.Writer header http.Header status int // status code passed to WriteHeader @@ -22,11 +23,12 @@ type responseWriter struct { } var _ http.ResponseWriter = &responseWriter{} +var _ http.Flusher = &responseWriter{} func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter { return &responseWriter{ header: http.Header{}, - stream: stream, + stream: bufio.NewWriter(stream), logger: logger, } } @@ -79,10 +81,11 @@ func (w *responseWriter) Write(p []byte) (int, error) { return w.stream.Write(p) } -func (w *responseWriter) Flush() {} - -// test that we implement http.Flusher -var _ http.Flusher = &responseWriter{} +func (w *responseWriter) Flush() { + if err := w.stream.Flush(); err != nil { + w.logger.Errorf("could not flush to stream: %s", err.Error()) + } +} // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 8697da808..ab87e58c0 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -24,6 +24,7 @@ var _ = Describe("Response Writer", func() { }) decodeHeader := func(str io.Reader) map[string][]string { + rw.Flush() fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) diff --git a/http3/server.go b/http3/server.go index 258c8d973..7d2abffbb 100644 --- a/http3/server.go +++ b/http3/server.go @@ -270,6 +270,7 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr()) req = req.WithContext(ctx) responseWriter := newResponseWriter(str, s.logger) + defer responseWriter.Flush() handler := s.Handler if handler == nil { handler = http.DefaultServeMux diff --git a/http3/server_test.go b/http3/server_test.go index 9eed586de..6f8e216bb 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -57,14 +57,14 @@ var _ = Describe("Server", func() { decoder := qpack.NewDecoder(nil) frame, err := parseNextFrame(str) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) data := make([]byte, headersFrame.Length) _, err = io.ReadFull(str, data) - Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) hfs, err := decoder.DecodeFull(data) - Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) for _, p := range hfs { fields[p.Name] = append(fields[p.Name], p.Value) }