implement a stream.Reset() method

ref #380
This commit is contained in:
Marten Seemann
2017-01-06 11:27:13 +07:00
parent b741724069
commit a86f31d789
5 changed files with 123 additions and 45 deletions

View File

@@ -25,12 +25,12 @@ type stream struct {
readOffset protocol.ByteCount
// Once set, the errors must not be changed!
readErr error
writeErr error
err error
cancelled utils.AtomicBool
finishedReading utils.AtomicBool
finishedWriting utils.AtomicBool
resetLocally utils.AtomicBool
frameQueue *streamFrameSorter
newFrameOrErrCond sync.Cond
@@ -59,8 +59,8 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo
// Read implements io.Reader. It is not thread safe!
func (s *stream) Read(p []byte) (int, error) {
if s.cancelled.Get() {
return 0, s.readErr
if s.cancelled.Get() || s.resetLocally.Get() {
return 0, s.err
}
if s.finishedReading.Get() {
return 0, io.EOF
@@ -73,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) {
if frame == nil && bytesRead > 0 {
s.mutex.Unlock()
return bytesRead, s.readErr
return bytesRead, s.err
}
var err error
for {
// Stop waiting on errors
if s.readErr != nil {
err = s.readErr
if s.err != nil {
err = s.err
break
}
if frame != nil {
@@ -134,11 +134,15 @@ func (s *stream) Read(p []byte) (int, error) {
}
func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if s.writeErr != nil {
return 0, s.writeErr
if s.err != nil {
return 0, s.err
}
if len(p) == 0 {
@@ -150,12 +154,12 @@ func (s *stream) Write(p []byte) (int, error) {
s.onData()
for s.dataForWriting != nil && s.writeErr == nil {
for s.dataForWriting != nil && s.err == nil {
s.doneWritingOrErrCond.Wait()
}
if s.writeErr != nil {
return 0, s.writeErr
if s.err != nil {
return 0, s.err
}
return len(p), nil
@@ -164,7 +168,7 @@ func (s *stream) Write(p []byte) (int, error) {
func (s *stream) lenOfDataForWriting() protocol.ByteCount {
s.mutex.Lock()
var l protocol.ByteCount
if s.writeErr == nil {
if s.err == nil {
l = protocol.ByteCount(len(s.dataForWriting))
}
s.mutex.Unlock()
@@ -173,7 +177,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock()
if s.writeErr != nil {
if s.err != nil {
s.mutex.Unlock()
return nil
}
@@ -204,7 +208,7 @@ func (s *stream) Close() error {
func (s *stream) shouldSendFin() bool {
s.mutex.Lock()
res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil
res := s.finishedWriting.Get() && !s.finSent && s.err == nil && s.dataForWriting == nil
s.mutex.Unlock()
return res
}
@@ -247,12 +251,26 @@ func (s *stream) Cancel(err error) {
s.mutex.Lock()
// errors must not be changed!
if s.readErr == nil {
s.readErr = err
if s.err == nil {
s.err = err
s.newFrameOrErrCond.Signal()
s.doneWritingOrErrCond.Signal()
}
if s.writeErr == nil {
s.writeErr = err
// if s.writeErr == nil {
// s.writeErr = err
// }
s.mutex.Unlock()
}
// resets the stream locally
func (s *stream) Reset(err error) {
s.finishedReading.Set(true)
s.resetLocally.Set(true)
s.mutex.Lock()
// errors must not be changed!
if s.err == nil {
s.err = err
s.newFrameOrErrCond.Signal()
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
@@ -263,8 +281,8 @@ func (s *stream) RegisterRemoteError(err error) {
s.finishedWriting.Set(true)
s.mutex.Lock()
// errors must not be changed!
if s.writeErr == nil {
s.writeErr = err
if s.err == nil {
s.err = err
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
@@ -277,7 +295,7 @@ func (s *stream) finishedRead() bool {
func (s *stream) finishedWriteAndSentFin() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent)
return s.err != nil || (s.finishedWriting.Get() && s.finSent)
}
func (s *stream) finished() bool {