fix a race when multiple requests were handled at the same time

fixes #157
This commit is contained in:
Lucas Clemente
2016-05-31 19:05:33 +02:00
parent a7afffe700
commit 4b2ab55435
5 changed files with 24 additions and 12 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"net/http" "net/http"
"strconv" "strconv"
"sync"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
@@ -13,19 +14,22 @@ import (
type responseWriter struct { type responseWriter struct {
dataStreamID protocol.StreamID dataStreamID protocol.StreamID
headerStream utils.Stream
dataStream utils.Stream dataStream utils.Stream
headerStream utils.Stream
headerStreamMutex *sync.Mutex
header http.Header header http.Header
headerWritten bool headerWritten bool
} }
func newResponseWriter(headerStream, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter { func newResponseWriter(headerStream utils.Stream, headerStreamMutex *sync.Mutex, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
headerStream: headerStream, headerStream: headerStream,
dataStream: dataStream, headerStreamMutex: headerStreamMutex,
dataStreamID: dataStreamID, dataStream: dataStream,
dataStreamID: dataStreamID,
} }
} }
@@ -45,6 +49,8 @@ func (w *responseWriter) WriteHeader(status int) {
} }
utils.Infof("Responding with %d", status) utils.Infof("Responding with %d", status)
w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil) h2framer := http2.NewFramer(w.headerStream, nil)
err := h2framer.WriteHeaders(http2.HeadersFrameParam{ err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(w.dataStreamID), StreamID: uint32(w.dataStreamID),

View File

@@ -3,6 +3,7 @@ package h2quic
import ( import (
"bytes" "bytes"
"net/http" "net/http"
"sync"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@@ -29,7 +30,7 @@ var _ = Describe("Response Writer", func() {
BeforeEach(func() { BeforeEach(func() {
headerStream = &mockStream{} headerStream = &mockStream{}
dataStream = &mockStream{} dataStream = &mockStream{}
w = newResponseWriter(headerStream, dataStream, 5) w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5)
}) })
It("writes status", func() { It("writes status", func() {

View File

@@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@@ -83,8 +84,9 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) {
h2framer := http2.NewFramer(nil, stream) h2framer := http2.NewFramer(nil, stream)
go func() { go func() {
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for { for {
if err := s.handleRequest(session, stream, hpackDecoder, h2framer); err != nil { if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
utils.Errorf("error handling h2 request: %s", err.Error()) utils.Errorf("error handling h2 request: %s", err.Error())
return return
} }
@@ -92,7 +94,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) {
}() }()
} }
func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
h2frame, err := h2framer.ReadFrame() h2frame, err := h2framer.ReadFrame()
if err != nil { if err != nil {
return err return err
@@ -125,7 +127,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
// stream's Close() closes the write side, not the read side // stream's Close() closes the write side, not the read side
req.Body = ioutil.NopCloser(dataStream) req.Body = ioutil.NopCloser(dataStream)
responseWriter := newResponseWriter(headerStream, dataStream, protocol.StreamID(h2headersFrame.StreamID)) responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
go func() { go func() {
handler := s.Handler handler := s.Handler

View File

@@ -2,6 +2,7 @@ package h2quic
import ( import (
"net/http" "net/http"
"sync"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
@@ -66,7 +67,7 @@ var _ = Describe("H2 server", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
}) })
err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue()) Expect(dataStream.remoteClosed).To(BeTrue())
@@ -83,7 +84,7 @@ var _ = Describe("H2 server", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
}) })
err := s.handleRequest(session, headerStream, hpackDecoder, h2framer) err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue()) Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeFalse()) Expect(dataStream.remoteClosed).To(BeFalse())

View File

@@ -25,6 +25,8 @@ var (
) )
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
//
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
type stream struct { type stream struct {
streamID protocol.StreamID streamID protocol.StreamID
session streamHandler session streamHandler