forked from quic-go/quic-go
replace the sync.Cond for stream.Write() by a channel
This commit is contained in:
69
stream.go
69
stream.go
@@ -43,10 +43,10 @@ type stream struct {
|
|||||||
frameQueue *streamFrameSorter
|
frameQueue *streamFrameSorter
|
||||||
readChan chan struct{}
|
readChan chan struct{}
|
||||||
|
|
||||||
dataForWriting []byte
|
dataForWriting []byte
|
||||||
finSent utils.AtomicBool
|
finSent utils.AtomicBool
|
||||||
rstSent utils.AtomicBool
|
rstSent utils.AtomicBool
|
||||||
doneWritingOrErrCond sync.Cond
|
writeChan chan struct{}
|
||||||
|
|
||||||
flowControlManager flowcontrol.FlowControlManager
|
flowControlManager flowcontrol.FlowControlManager
|
||||||
}
|
}
|
||||||
@@ -56,16 +56,15 @@ func newStream(StreamID protocol.StreamID,
|
|||||||
onData func(),
|
onData func(),
|
||||||
onReset func(protocol.StreamID, protocol.ByteCount),
|
onReset func(protocol.StreamID, protocol.ByteCount),
|
||||||
flowControlManager flowcontrol.FlowControlManager) *stream {
|
flowControlManager flowcontrol.FlowControlManager) *stream {
|
||||||
s := &stream{
|
return &stream{
|
||||||
onData: onData,
|
onData: onData,
|
||||||
onReset: onReset,
|
onReset: onReset,
|
||||||
streamID: StreamID,
|
streamID: StreamID,
|
||||||
flowControlManager: flowControlManager,
|
flowControlManager: flowControlManager,
|
||||||
frameQueue: newStreamFrameSorter(),
|
frameQueue: newStreamFrameSorter(),
|
||||||
readChan: make(chan struct{}, 1),
|
readChan: make(chan struct{}, 1),
|
||||||
|
writeChan: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
s.doneWritingOrErrCond.L = &s.mutex
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read implements io.Reader. It is not thread safe!
|
// Read implements io.Reader. It is not thread safe!
|
||||||
@@ -147,30 +146,35 @@ func (s *stream) Read(p []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *stream) Write(p []byte) (int, error) {
|
func (s *stream) Write(p []byte) (int, error) {
|
||||||
if s.resetLocally.Get() {
|
|
||||||
return 0, s.err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
if s.resetLocally.Get() || s.err != nil {
|
||||||
|
err := s.err
|
||||||
if s.err != nil {
|
s.mutex.Unlock()
|
||||||
return 0, s.err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p) == 0 {
|
if len(p) == 0 {
|
||||||
|
s.mutex.Unlock()
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dataForWriting = make([]byte, len(p))
|
s.dataForWriting = make([]byte, len(p))
|
||||||
copy(s.dataForWriting, p)
|
copy(s.dataForWriting, p)
|
||||||
|
|
||||||
s.onData()
|
s.onData()
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
for s.dataForWriting != nil && s.err == nil {
|
for {
|
||||||
s.doneWritingOrErrCond.Wait()
|
s.mutex.Lock()
|
||||||
|
if s.dataForWriting == nil || s.err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.mutex.Unlock()
|
||||||
|
<-s.writeChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
if s.err != nil {
|
if s.err != nil {
|
||||||
return len(p) - len(s.dataForWriting), s.err
|
return len(p) - len(s.dataForWriting), s.err
|
||||||
}
|
}
|
||||||
@@ -190,14 +194,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
|||||||
|
|
||||||
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
if s.err != nil {
|
defer s.mutex.Unlock()
|
||||||
s.mutex.Unlock()
|
|
||||||
return nil
|
if s.err != nil || s.dataForWriting == nil {
|
||||||
}
|
|
||||||
if s.dataForWriting == nil {
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ret []byte
|
var ret []byte
|
||||||
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
|
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
|
||||||
ret = s.dataForWriting[:maxBytes]
|
ret = s.dataForWriting[:maxBytes]
|
||||||
@@ -205,10 +207,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
|||||||
} else {
|
} else {
|
||||||
ret = s.dataForWriting
|
ret = s.dataForWriting
|
||||||
s.dataForWriting = nil
|
s.dataForWriting = nil
|
||||||
s.doneWritingOrErrCond.Signal()
|
s.signalWrite()
|
||||||
}
|
}
|
||||||
s.writeOffset += protocol.ByteCount(len(ret))
|
s.writeOffset += protocol.ByteCount(len(ret))
|
||||||
s.mutex.Unlock()
|
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,6 +264,14 @@ func (s *stream) signalRead() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// signalRead performs a non-blocking send on the writeChan
|
||||||
|
func (s *stream) signalWrite() {
|
||||||
|
select {
|
||||||
|
case s.writeChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
|
// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset
|
||||||
func (s *stream) CloseRemote(offset protocol.ByteCount) {
|
func (s *stream) CloseRemote(offset protocol.ByteCount) {
|
||||||
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
|
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
|
||||||
@@ -277,7 +286,7 @@ func (s *stream) Cancel(err error) {
|
|||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.err = err
|
s.err = err
|
||||||
s.signalRead()
|
s.signalRead()
|
||||||
s.doneWritingOrErrCond.Signal()
|
s.signalWrite()
|
||||||
}
|
}
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -293,7 +302,7 @@ func (s *stream) Reset(err error) {
|
|||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.err = err
|
s.err = err
|
||||||
s.signalRead()
|
s.signalRead()
|
||||||
s.doneWritingOrErrCond.Signal()
|
s.signalWrite()
|
||||||
}
|
}
|
||||||
if s.shouldSendReset() {
|
if s.shouldSendReset() {
|
||||||
s.onReset(s.streamID, s.writeOffset)
|
s.onReset(s.streamID, s.writeOffset)
|
||||||
@@ -312,7 +321,7 @@ func (s *stream) RegisterRemoteError(err error) {
|
|||||||
// errors must not be changed!
|
// errors must not be changed!
|
||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.err = err
|
s.err = err
|
||||||
s.doneWritingOrErrCond.Signal()
|
s.signalWrite()
|
||||||
}
|
}
|
||||||
if s.shouldSendReset() {
|
if s.shouldSendReset() {
|
||||||
s.onReset(s.streamID, s.writeOffset)
|
s.onReset(s.streamID, s.writeOffset)
|
||||||
|
|||||||
Reference in New Issue
Block a user