diff --git a/stream.go b/stream.go index 3e2a7303a..ce0ffdb28 100644 --- a/stream.go +++ b/stream.go @@ -43,10 +43,10 @@ type stream struct { frameQueue *streamFrameSorter readChan chan struct{} - dataForWriting []byte - finSent utils.AtomicBool - rstSent utils.AtomicBool - doneWritingOrErrCond sync.Cond + dataForWriting []byte + finSent utils.AtomicBool + rstSent utils.AtomicBool + writeChan chan struct{} flowControlManager flowcontrol.FlowControlManager } @@ -56,16 +56,15 @@ func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) *stream { - s := &stream{ + return &stream{ onData: onData, onReset: onReset, streamID: StreamID, flowControlManager: flowControlManager, frameQueue: newStreamFrameSorter(), readChan: make(chan struct{}, 1), + writeChan: make(chan struct{}, 1), } - s.doneWritingOrErrCond.L = &s.mutex - return s } // Read implements io.Reader. It is not thread safe! @@ -147,30 +146,35 @@ func (s *stream) Read(p []byte) (int, error) { } func (s *stream) Write(p []byte) (int, error) { - if s.resetLocally.Get() { - return 0, s.err - } - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.err != nil { - return 0, s.err + if s.resetLocally.Get() || s.err != nil { + err := s.err + s.mutex.Unlock() + return 0, err } - if len(p) == 0 { + s.mutex.Unlock() return 0, nil } s.dataForWriting = make([]byte, len(p)) copy(s.dataForWriting, p) - s.onData() + s.mutex.Unlock() - for s.dataForWriting != nil && s.err == nil { - s.doneWritingOrErrCond.Wait() + for { + s.mutex.Lock() + if s.dataForWriting == nil || s.err != nil { + s.mutex.Unlock() + break + } + s.mutex.Unlock() + <-s.writeChan } + s.mutex.Lock() + defer s.mutex.Unlock() + if s.err != nil { return len(p) - len(s.dataForWriting), s.err } @@ -190,14 +194,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() - if s.err != nil { - s.mutex.Unlock() - return nil - } - if s.dataForWriting == nil { - s.mutex.Unlock() + defer s.mutex.Unlock() + + if s.err != nil || s.dataForWriting == nil { return nil } + var ret []byte if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { ret = s.dataForWriting[:maxBytes] @@ -205,10 +207,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { } else { ret = s.dataForWriting s.dataForWriting = nil - s.doneWritingOrErrCond.Signal() + s.signalWrite() } s.writeOffset += protocol.ByteCount(len(ret)) - s.mutex.Unlock() return ret } @@ -263,6 +264,14 @@ func (s *stream) signalRead() { } } +// signalRead performs a non-blocking send on the writeChan +func (s *stream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} + // CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) @@ -277,7 +286,7 @@ func (s *stream) Cancel(err error) { if s.err == nil { s.err = err s.signalRead() - s.doneWritingOrErrCond.Signal() + s.signalWrite() } s.mutex.Unlock() } @@ -293,7 +302,7 @@ func (s *stream) Reset(err error) { if s.err == nil { s.err = err s.signalRead() - s.doneWritingOrErrCond.Signal() + s.signalWrite() } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) @@ -312,7 +321,7 @@ func (s *stream) RegisterRemoteError(err error) { // errors must not be changed! if s.err == nil { s.err = err - s.doneWritingOrErrCond.Signal() + s.signalWrite() } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset)