replace the sync.Cond for stream.Write() by a channel

This commit is contained in:
Marten Seemann
2017-04-15 12:24:17 +07:00
parent 5fbd52158f
commit a70ae86f5a

View File

@@ -43,10 +43,10 @@ type stream struct {
frameQueue *streamFrameSorter
readChan chan struct{}
dataForWriting []byte
finSent utils.AtomicBool
rstSent utils.AtomicBool
doneWritingOrErrCond sync.Cond
dataForWriting []byte
finSent utils.AtomicBool
rstSent utils.AtomicBool
writeChan chan struct{}
flowControlManager flowcontrol.FlowControlManager
}
@@ -56,16 +56,15 @@ func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream {
s := &stream{
return &stream{
onData: onData,
onReset: onReset,
streamID: StreamID,
flowControlManager: flowControlManager,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
}
s.doneWritingOrErrCond.L = &s.mutex
return s
}
// Read implements io.Reader. It is not thread safe!
@@ -147,30 +146,35 @@ func (s *stream) Read(p []byte) (int, error) {
}
func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil {
return 0, s.err
if s.resetLocally.Get() || s.err != nil {
err := s.err
s.mutex.Unlock()
return 0, err
}
if len(p) == 0 {
s.mutex.Unlock()
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.onData()
s.mutex.Unlock()
for s.dataForWriting != nil && s.err == nil {
s.doneWritingOrErrCond.Wait()
for {
s.mutex.Lock()
if s.dataForWriting == nil || s.err != nil {
s.mutex.Unlock()
break
}
s.mutex.Unlock()
<-s.writeChan
}
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil {
return len(p) - len(s.dataForWriting), s.err
}
@@ -190,14 +194,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock()
if s.err != nil {
s.mutex.Unlock()
return nil
}
if s.dataForWriting == nil {
s.mutex.Unlock()
defer s.mutex.Unlock()
if s.err != nil || s.dataForWriting == nil {
return nil
}
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
@@ -205,10 +207,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
} else {
ret = s.dataForWriting
s.dataForWriting = nil
s.doneWritingOrErrCond.Signal()
s.signalWrite()
}
s.writeOffset += protocol.ByteCount(len(ret))
s.mutex.Unlock()
return ret
}
@@ -263,6 +264,14 @@ func (s *stream) signalRead() {
}
}
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
func (s *stream) CloseRemote(offset protocol.ByteCount) {
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
@@ -277,7 +286,7 @@ func (s *stream) Cancel(err error) {
if s.err == nil {
s.err = err
s.signalRead()
s.doneWritingOrErrCond.Signal()
s.signalWrite()
}
s.mutex.Unlock()
}
@@ -293,7 +302,7 @@ func (s *stream) Reset(err error) {
if s.err == nil {
s.err = err
s.signalRead()
s.doneWritingOrErrCond.Signal()
s.signalWrite()
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)
@@ -312,7 +321,7 @@ func (s *stream) RegisterRemoteError(err error) {
// errors must not be changed!
if s.err == nil {
s.err = err
s.doneWritingOrErrCond.Signal()
s.signalWrite()
}
if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset)