diff --git a/example/main.go b/example/main.go index 96c28cbb..875c8354 100644 --- a/example/main.go +++ b/example/main.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "crypto/x509" "flag" + "fmt" "io/ioutil" "net/http" @@ -36,6 +37,13 @@ func main() { } } + http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + fmt.Printf("error reading body while handling /echo: %s\n", err.Error()) + } + w.Write(body) + }) http.Handle("/", http.FileServer(http.Dir(*www))) server, err := h2quic.NewServer(tlsConfig) diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index a2897869..2b7800bb 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -2,7 +2,6 @@ package h2quic import ( "bytes" - "fmt" "net/http" "strconv" @@ -12,12 +11,7 @@ import ( "golang.org/x/net/http2/hpack" ) -type streamCreator interface { - NewStream(protocol.StreamID) (utils.Stream, error) -} - type responseWriter struct { - session streamCreator dataStreamID protocol.StreamID headerStream utils.Stream dataStream utils.Stream @@ -26,12 +20,12 @@ type responseWriter struct { headerWritten bool } -func newResponseWriter(headerStream utils.Stream, dataStreamID protocol.StreamID, session streamCreator) *responseWriter { +func newResponseWriter(headerStream, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter { return &responseWriter{ header: http.Header{}, headerStream: headerStream, + dataStream: dataStream, dataStreamID: dataStreamID, - session: session, } } @@ -66,16 +60,5 @@ func (w *responseWriter) Write(p []byte) (int, error) { if !w.headerWritten { w.WriteHeader(200) } - - if len(p) != 0 { - if w.dataStream == nil { - var err error - w.dataStream, err = w.session.NewStream(w.dataStreamID) - if err != nil { - return 0, fmt.Errorf("error creating data stream: %s", err.Error()) - } - } - return w.dataStream.Write(p) - } - return 0, nil + return w.dataStream.Write(p) } diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index c3e39528..b622b2e4 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -4,8 +4,6 @@ import ( "bytes" "net/http" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -16,28 +14,17 @@ type mockStream struct { func (mockStream) Close() error { return nil } -type mockSession struct { - stream *mockStream -} - -func (s *mockSession) NewStream(id protocol.StreamID) (utils.Stream, error) { - Expect(id).To(Equal(protocol.StreamID(5))) - return s.stream, nil -} - var _ = Describe("Response Writer", func() { var ( w *responseWriter headerStream *mockStream dataStream *mockStream - s *mockSession ) BeforeEach(func() { headerStream = &mockStream{} dataStream = &mockStream{} - s = &mockSession{stream: dataStream} - w = newResponseWriter(headerStream, 5, s) + w = newResponseWriter(headerStream, dataStream, 5) }) It("writes status", func() { diff --git a/h2quic/server.go b/h2quic/server.go index c9457d7e..7b15a19f 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -3,6 +3,7 @@ package h2quic import ( "crypto/tls" "errors" + "io/ioutil" "net/http" "github.com/lucas-clemente/quic-go" @@ -12,6 +13,10 @@ import ( "golang.org/x/net/http2/hpack" ) +type streamCreator interface { + GetOrCreateStream(protocol.StreamID) (utils.Stream, error) +} + // Server is a HTTP2 server listening for QUIC connections type Server struct { server *quic.Server @@ -23,7 +28,7 @@ func NewServer(tlsConfig *tls.Config) (*Server, error) { s := &Server{} var err error - s.server, err = quic.NewServer(tlsConfig, s.handleStream) + s.server, err = quic.NewServer(tlsConfig, s.handleStreamCb) if err != nil { return nil, err } @@ -41,7 +46,11 @@ func (s *Server) ListenAndServe(addr string, handler http.Handler) error { return s.server.ListenAndServe(addr) } -func (s *Server) handleStream(session *quic.Session, headerStream utils.Stream) { +func (s *Server) handleStreamCb(session *quic.Session, headerStream utils.Stream) { + s.handleStream(session, headerStream) +} + +func (s *Server) handleStream(session streamCreator, headerStream utils.Stream) { hpackDecoder := hpack.NewDecoder(4096, nil) h2framer := http2.NewFramer(nil, headerStream) @@ -55,7 +64,7 @@ func (s *Server) handleStream(session *quic.Session, headerStream utils.Stream) }() } -func (s *Server) handleRequest(session *quic.Session, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { +func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { h2frame, err := h2framer.ReadFrame() if err != nil { return err @@ -76,13 +85,16 @@ func (s *Server) handleRequest(session *quic.Session, headerStream utils.Stream, } utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) - responseWriter := &responseWriter{ - header: http.Header{}, - headerStream: headerStream, - dataStreamID: protocol.StreamID(h2headersFrame.StreamID), - session: session, + dataStream, err := session.GetOrCreateStream(protocol.StreamID(h2headersFrame.StreamID)) + if err != nil { + return err } + // stream's Close() closes the write side, not the read side + req.Body = ioutil.NopCloser(dataStream) + + responseWriter := newResponseWriter(headerStream, dataStream, protocol.StreamID(h2headersFrame.StreamID)) + go func() { s.handler.ServeHTTP(responseWriter, req) if responseWriter.dataStream != nil { diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 0832428f..92519fe9 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -6,15 +6,24 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/testdata" + "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +type mockSession struct{} + +func (s *mockSession) GetOrCreateStream(id protocol.StreamID) (utils.Stream, error) { + return &mockStream{}, nil +} + var _ = Describe("Response Writer", func() { var ( - s *Server + s *Server + session *mockSession ) BeforeEach(func() { @@ -22,6 +31,7 @@ var _ = Describe("Response Writer", func() { s, err = NewServer(testdata.GetTLSConfig()) Expect(err).NotTo(HaveOccurred()) Expect(s).NotTo(BeNil()) + session = &mockSession{} }) It("uses default handler", func() { @@ -63,11 +73,10 @@ var _ = Describe("Response Writer", func() { // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, }) - err := s.handleRequest(nil, headerStream, hpackDecoder, h2framer) + err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) }) - }) It("handles the header stream", func() { @@ -82,7 +91,7 @@ var _ = Describe("Response Writer", func() { // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, }) - s.handleStream(nil, headerStream) + s.handleStream(session, headerStream) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) }) }) diff --git a/session.go b/session.go index a3d2b23d..9bfc175e 100644 --- a/session.go +++ b/session.go @@ -526,12 +526,24 @@ func (s *Session) QueueStreamFrame(frame *frames.StreamFrame) error { func (s *Session) NewStream(id protocol.StreamID) (utils.Stream, error) { s.streamsMutex.Lock() defer s.streamsMutex.Unlock() - stream, err := newStream(s, s.connectionParametersManager, id) + return s.newStreamImpl(id) +} +// GetOrCreateStream returns an existing stream with the given id, or opens a new stream +func (s *Session) GetOrCreateStream(id protocol.StreamID) (utils.Stream, error) { + s.streamsMutex.Lock() + defer s.streamsMutex.Unlock() + if stream, ok := s.streams[id]; ok { + return stream, nil + } + return s.newStreamImpl(id) +} + +func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { + stream, err := newStream(s, s.connectionParametersManager, id) if err != nil { return nil, err } - if s.streams[id] != nil { return nil, fmt.Errorf("Session: stream with ID %d already exists", id) }