forked from quic-go/quic-go
replace the sync.Cond for stream.Write() by a channel
This commit is contained in:
69
stream.go
69
stream.go
@@ -43,10 +43,10 @@ type stream struct {
|
||||
frameQueue *streamFrameSorter
|
||||
readChan chan struct{}
|
||||
|
||||
dataForWriting []byte
|
||||
finSent utils.AtomicBool
|
||||
rstSent utils.AtomicBool
|
||||
doneWritingOrErrCond sync.Cond
|
||||
dataForWriting []byte
|
||||
finSent utils.AtomicBool
|
||||
rstSent utils.AtomicBool
|
||||
writeChan chan struct{}
|
||||
|
||||
flowControlManager flowcontrol.FlowControlManager
|
||||
}
|
||||
@@ -56,16 +56,15 @@ func newStream(StreamID protocol.StreamID,
|
||||
onData func(),
|
||||
onReset func(protocol.StreamID, protocol.ByteCount),
|
||||
flowControlManager flowcontrol.FlowControlManager) *stream {
|
||||
s := &stream{
|
||||
return &stream{
|
||||
onData: onData,
|
||||
onReset: onReset,
|
||||
streamID: StreamID,
|
||||
flowControlManager: flowControlManager,
|
||||
frameQueue: newStreamFrameSorter(),
|
||||
readChan: make(chan struct{}, 1),
|
||||
writeChan: make(chan struct{}, 1),
|
||||
}
|
||||
s.doneWritingOrErrCond.L = &s.mutex
|
||||
return s
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It is not thread safe!
|
||||
@@ -147,30 +146,35 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (s *stream) Write(p []byte) (int, error) {
|
||||
if s.resetLocally.Get() {
|
||||
return 0, s.err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.err != nil {
|
||||
return 0, s.err
|
||||
if s.resetLocally.Get() || s.err != nil {
|
||||
err := s.err
|
||||
s.mutex.Unlock()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
s.mutex.Unlock()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s.dataForWriting = make([]byte, len(p))
|
||||
copy(s.dataForWriting, p)
|
||||
|
||||
s.onData()
|
||||
s.mutex.Unlock()
|
||||
|
||||
for s.dataForWriting != nil && s.err == nil {
|
||||
s.doneWritingOrErrCond.Wait()
|
||||
for {
|
||||
s.mutex.Lock()
|
||||
if s.dataForWriting == nil || s.err != nil {
|
||||
s.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
<-s.writeChan
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.err != nil {
|
||||
return len(p) - len(s.dataForWriting), s.err
|
||||
}
|
||||
@@ -190,14 +194,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
||||
|
||||
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||
s.mutex.Lock()
|
||||
if s.err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
if s.dataForWriting == nil {
|
||||
s.mutex.Unlock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.err != nil || s.dataForWriting == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ret []byte
|
||||
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
|
||||
ret = s.dataForWriting[:maxBytes]
|
||||
@@ -205,10 +207,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||
} else {
|
||||
ret = s.dataForWriting
|
||||
s.dataForWriting = nil
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
s.signalWrite()
|
||||
}
|
||||
s.writeOffset += protocol.ByteCount(len(ret))
|
||||
s.mutex.Unlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -263,6 +264,14 @@ func (s *stream) signalRead() {
|
||||
}
|
||||
}
|
||||
|
||||
// signalRead performs a non-blocking send on the writeChan
|
||||
func (s *stream) signalWrite() {
|
||||
select {
|
||||
case s.writeChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
|
||||
func (s *stream) CloseRemote(offset protocol.ByteCount) {
|
||||
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
|
||||
@@ -277,7 +286,7 @@ func (s *stream) Cancel(err error) {
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.signalRead()
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
s.signalWrite()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
@@ -293,7 +302,7 @@ func (s *stream) Reset(err error) {
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.signalRead()
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
s.signalWrite()
|
||||
}
|
||||
if s.shouldSendReset() {
|
||||
s.onReset(s.streamID, s.writeOffset)
|
||||
@@ -312,7 +321,7 @@ func (s *stream) RegisterRemoteError(err error) {
|
||||
// errors must not be changed!
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
s.signalWrite()
|
||||
}
|
||||
if s.shouldSendReset() {
|
||||
s.onReset(s.streamID, s.writeOffset)
|
||||
|
||||
Reference in New Issue
Block a user