diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index 3499b4e1b..002a27458 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -6,15 +6,18 @@ import ( "net/http" "strconv" - "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) +type streamCreator interface { + NewStream(protocol.StreamID) (utils.Stream, error) +} + type responseWriter struct { - session *quic.Session + session streamCreator dataStreamID protocol.StreamID headerStream utils.Stream dataStream utils.Stream @@ -23,7 +26,7 @@ type responseWriter struct { headerWritten bool } -func newResponseWriter(headerStream utils.Stream, dataStreamID protocol.StreamID, session *quic.Session) *responseWriter { +func newResponseWriter(headerStream utils.Stream, dataStreamID protocol.StreamID, session streamCreator) *responseWriter { return &responseWriter{ header: http.Header{}, headerStream: headerStream, @@ -59,7 +62,6 @@ func (w *responseWriter) WriteHeader(status int) { } } -// TODO: Test func (w *responseWriter) Write(p []byte) (int, error) { if !w.headerWritten { w.WriteHeader(200) diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 2c6e69cce..c3e395286 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net/http" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,15 +16,28 @@ type mockStream struct { func (mockStream) Close() error { return nil } +type mockSession struct { + stream *mockStream +} + +func (s *mockSession) NewStream(id protocol.StreamID) (utils.Stream, error) { + Expect(id).To(Equal(protocol.StreamID(5))) + return s.stream, nil +} + var _ = Describe("Response Writer", func() { var ( w *responseWriter headerStream *mockStream + dataStream *mockStream + s *mockSession ) BeforeEach(func() { headerStream = &mockStream{} - w = newResponseWriter(headerStream, 5, nil) + dataStream = &mockStream{} + s = &mockSession{stream: dataStream} + w = newResponseWriter(headerStream, 5, s) }) It("writes status", func() { @@ -41,4 +56,33 @@ var _ = Describe("Response Writer", func() { 0x2, 0x34, 0x32, })) }) + + It("writes data", func() { + n, err := w.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + // Should have written 200 on the header stream + Expect(headerStream.Bytes()).To(Equal([]byte{ + 0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88, + })) + // And foobar on the data stream + Expect(dataStream.Bytes()).To(Equal([]byte{ + 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, + })) + }) + + It("writes data after WriteHeader is called", func() { + w.WriteHeader(http.StatusTeapot) + n, err := w.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).ToNot(HaveOccurred()) + // Should have written 418 on the header stream + Expect(headerStream.Bytes()).To(Equal([]byte{ + 0x0, 0x0, 0x5, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 'H', 0x3, '4', '1', '8', + })) + // And foobar on the data stream + Expect(dataStream.Bytes()).To(Equal([]byte{ + 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, + })) + }) })