replace the sync.Cond for stream.Write() by a channel

This commit is contained in:
Marten Seemann
2017-04-15 12:24:17 +07:00
parent 5fbd52158f
commit a70ae86f5a

View File

@@ -43,10 +43,10 @@ type stream struct {
frameQueue *streamFrameSorter frameQueue *streamFrameSorter
readChan chan struct{} readChan chan struct{}
dataForWriting []byte dataForWriting []byte
finSent utils.AtomicBool finSent utils.AtomicBool
rstSent utils.AtomicBool rstSent utils.AtomicBool
doneWritingOrErrCond sync.Cond writeChan chan struct{}
flowControlManager flowcontrol.FlowControlManager flowControlManager flowcontrol.FlowControlManager
} }
@@ -56,16 +56,15 @@ func newStream(StreamID protocol.StreamID,
onData func(), onData func(),
onReset func(protocol.StreamID, protocol.ByteCount), onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream { flowControlManager flowcontrol.FlowControlManager) *stream {
s := &stream{ return &stream{
onData: onData, onData: onData,
onReset: onReset, onReset: onReset,
streamID: StreamID, streamID: StreamID,
flowControlManager: flowControlManager, flowControlManager: flowControlManager,
frameQueue: newStreamFrameSorter(), frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1), readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
} }
s.doneWritingOrErrCond.L = &s.mutex
return s
} }
// Read implements io.Reader. It is not thread safe! // Read implements io.Reader. It is not thread safe!
@@ -147,30 +146,35 @@ func (s *stream) Read(p []byte) (int, error) {
} }
func (s *stream) Write(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() if s.resetLocally.Get() || s.err != nil {
err := s.err
if s.err != nil { s.mutex.Unlock()
return 0, s.err return 0, err
} }
if len(p) == 0 { if len(p) == 0 {
s.mutex.Unlock()
return 0, nil return 0, nil
} }
s.dataForWriting = make([]byte, len(p)) s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p) copy(s.dataForWriting, p)
s.onData() s.onData()
s.mutex.Unlock()
for s.dataForWriting != nil && s.err == nil { for {
s.doneWritingOrErrCond.Wait() s.mutex.Lock()
if s.dataForWriting == nil || s.err != nil {
s.mutex.Unlock()
break
}
s.mutex.Unlock()
<-s.writeChan
} }
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil { if s.err != nil {
return len(p) - len(s.dataForWriting), s.err return len(p) - len(s.dataForWriting), s.err
} }
@@ -190,14 +194,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock() s.mutex.Lock()
if s.err != nil { defer s.mutex.Unlock()
s.mutex.Unlock()
return nil if s.err != nil || s.dataForWriting == nil {
}
if s.dataForWriting == nil {
s.mutex.Unlock()
return nil return nil
} }
var ret []byte var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes] ret = s.dataForWriting[:maxBytes]
@@ -205,10 +207,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
} else { } else {
ret = s.dataForWriting ret = s.dataForWriting
s.dataForWriting = nil s.dataForWriting = nil
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.writeOffset += protocol.ByteCount(len(ret)) s.writeOffset += protocol.ByteCount(len(ret))
s.mutex.Unlock()
return ret return ret
} }
@@ -263,6 +264,14 @@ func (s *stream) signalRead() {
} }
} }
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset // CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
func (s *stream) CloseRemote(offset protocol.ByteCount) { func (s *stream) CloseRemote(offset protocol.ByteCount) {
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
@@ -277,7 +286,7 @@ func (s *stream) Cancel(err error) {
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.signalRead() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.mutex.Unlock() s.mutex.Unlock()
} }
@@ -293,7 +302,7 @@ func (s *stream) Reset(err error) {
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.signalRead() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)
@@ -312,7 +321,7 @@ func (s *stream) RegisterRemoteError(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)