diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index fc0d8a45..a730e34c 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "strconv" + "sync" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" @@ -13,19 +14,22 @@ import ( type responseWriter struct { dataStreamID protocol.StreamID - headerStream utils.Stream dataStream utils.Stream + headerStream utils.Stream + headerStreamMutex *sync.Mutex + header http.Header headerWritten bool } -func newResponseWriter(headerStream, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter { +func newResponseWriter(headerStream utils.Stream, headerStreamMutex *sync.Mutex, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter { return &responseWriter{ - header: http.Header{}, - headerStream: headerStream, - dataStream: dataStream, - dataStreamID: dataStreamID, + header: http.Header{}, + headerStream: headerStream, + headerStreamMutex: headerStreamMutex, + dataStream: dataStream, + dataStreamID: dataStreamID, } } @@ -45,6 +49,8 @@ func (w *responseWriter) WriteHeader(status int) { } utils.Infof("Responding with %d", status) + w.headerStreamMutex.Lock() + defer w.headerStreamMutex.Unlock() h2framer := http2.NewFramer(w.headerStream, nil) err := h2framer.WriteHeaders(http2.HeadersFrameParam{ StreamID: uint32(w.dataStreamID), diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 0a042585..0448c015 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -3,6 +3,7 @@ package h2quic import ( "bytes" "net/http" + "sync" "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" @@ -29,7 +30,7 @@ var _ = Describe("Response Writer", func() { BeforeEach(func() { headerStream = &mockStream{} dataStream = &mockStream{} - w = newResponseWriter(headerStream, dataStream, 5) + w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5) }) It("writes status", func() { diff --git a/h2quic/server.go b/h2quic/server.go index eeb99874..94dddf76 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "net/http" + "sync" "time" "github.com/lucas-clemente/quic-go" @@ -83,8 +84,9 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) { h2framer := http2.NewFramer(nil, stream) go func() { + var headerStreamMutex sync.Mutex // Protects concurrent calls to Write() for { - if err := s.handleRequest(session, stream, hpackDecoder, h2framer); err != nil { + if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { utils.Errorf("error handling h2 request: %s", err.Error()) return } @@ -92,7 +94,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) { }() } -func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { +func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { h2frame, err := h2framer.ReadFrame() if err != nil { return err @@ -125,7 +127,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, // stream's Close() closes the write side, not the read side req.Body = ioutil.NopCloser(dataStream) - responseWriter := newResponseWriter(headerStream, dataStream, protocol.StreamID(h2headersFrame.StreamID)) + responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) go func() { handler := s.Handler diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 9fcd13bc..bc54c1bd 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -2,6 +2,7 @@ package h2quic import ( "net/http" + "sync" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -66,7 +67,7 @@ var _ = Describe("H2 server", func() { // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, }) - err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeTrue()) @@ -83,7 +84,7 @@ var _ = Describe("H2 server", func() { // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, }) - err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) + err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) Expect(err).NotTo(HaveOccurred()) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Expect(dataStream.remoteClosed).To(BeFalse()) diff --git a/stream.go b/stream.go index fb46d82a..29e008e4 100644 --- a/stream.go +++ b/stream.go @@ -25,6 +25,8 @@ var ( ) // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface +// +// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { streamID protocol.StreamID session streamHandler