implement a stream.Reset() method

ref #380
This commit is contained in:
Marten Seemann
2017-01-06 11:27:13 +07:00
parent b741724069
commit a86f31d789
5 changed files with 123 additions and 45 deletions

View File

@@ -17,6 +17,7 @@ type mockStream struct {
}
func (mockStream) Close() error { return nil }
func (mockStream) Reset(error) { panic("not implemented") }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
func (s mockStream) StreamID() protocol.StreamID { return s.id }

View File

@@ -109,6 +109,7 @@ func (s *mockStream) Write(p []byte) (int, error) {
}
func (s *mockStream) Close() error { panic("not implemented") }
func (s *mockStream) Reset(error) { panic("not implemented") }
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") }
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }

View File

@@ -25,12 +25,12 @@ type stream struct {
readOffset protocol.ByteCount
// Once set, the errors must not be changed!
readErr error
writeErr error
err error
cancelled utils.AtomicBool
finishedReading utils.AtomicBool
finishedWriting utils.AtomicBool
resetLocally utils.AtomicBool
frameQueue *streamFrameSorter
newFrameOrErrCond sync.Cond
@@ -59,8 +59,8 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo
// Read implements io.Reader. It is not thread safe!
func (s *stream) Read(p []byte) (int, error) {
if s.cancelled.Get() {
return 0, s.readErr
if s.cancelled.Get() || s.resetLocally.Get() {
return 0, s.err
}
if s.finishedReading.Get() {
return 0, io.EOF
@@ -73,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) {
if frame == nil && bytesRead > 0 {
s.mutex.Unlock()
return bytesRead, s.readErr
return bytesRead, s.err
}
var err error
for {
// Stop waiting on errors
if s.readErr != nil {
err = s.readErr
if s.err != nil {
err = s.err
break
}
if frame != nil {
@@ -134,11 +134,15 @@ func (s *stream) Read(p []byte) (int, error) {
}
func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if s.writeErr != nil {
return 0, s.writeErr
if s.err != nil {
return 0, s.err
}
if len(p) == 0 {
@@ -150,12 +154,12 @@ func (s *stream) Write(p []byte) (int, error) {
s.onData()
for s.dataForWriting != nil && s.writeErr == nil {
for s.dataForWriting != nil && s.err == nil {
s.doneWritingOrErrCond.Wait()
}
if s.writeErr != nil {
return 0, s.writeErr
if s.err != nil {
return 0, s.err
}
return len(p), nil
@@ -164,7 +168,7 @@ func (s *stream) Write(p []byte) (int, error) {
func (s *stream) lenOfDataForWriting() protocol.ByteCount {
s.mutex.Lock()
var l protocol.ByteCount
if s.writeErr == nil {
if s.err == nil {
l = protocol.ByteCount(len(s.dataForWriting))
}
s.mutex.Unlock()
@@ -173,7 +177,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock()
if s.writeErr != nil {
if s.err != nil {
s.mutex.Unlock()
return nil
}
@@ -204,7 +208,7 @@ func (s *stream) Close() error {
func (s *stream) shouldSendFin() bool {
s.mutex.Lock()
res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil
res := s.finishedWriting.Get() && !s.finSent && s.err == nil && s.dataForWriting == nil
s.mutex.Unlock()
return res
}
@@ -247,12 +251,26 @@ func (s *stream) Cancel(err error) {
s.mutex.Lock()
// errors must not be changed!
if s.readErr == nil {
s.readErr = err
if s.err == nil {
s.err = err
s.newFrameOrErrCond.Signal()
s.doneWritingOrErrCond.Signal()
}
if s.writeErr == nil {
s.writeErr = err
// if s.writeErr == nil {
// s.writeErr = err
// }
s.mutex.Unlock()
}
// resets the stream locally
func (s *stream) Reset(err error) {
s.finishedReading.Set(true)
s.resetLocally.Set(true)
s.mutex.Lock()
// errors must not be changed!
if s.err == nil {
s.err = err
s.newFrameOrErrCond.Signal()
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
@@ -263,8 +281,8 @@ func (s *stream) RegisterRemoteError(err error) {
s.finishedWriting.Set(true)
s.mutex.Lock()
// errors must not be changed!
if s.writeErr == nil {
s.writeErr = err
if s.err == nil {
s.err = err
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
@@ -277,7 +295,7 @@ func (s *stream) finishedRead() bool {
func (s *stream) finishedWriteAndSentFin() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent)
return s.err != nil || (s.finishedWriting.Get() && s.finSent)
}
func (s *stream) finished() bool {

View File

@@ -403,33 +403,90 @@ var _ = Describe("Stream", func() {
})
Context("resetting", func() {
testErr := errors.New("received RST_STREAM")
testErr := errors.New("testErr")
It("continues reading after receiving a remote error", func() {
frame := frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
str.AddStreamFrame(&frame)
str.RegisterRemoteError(testErr)
b := make([]byte, 4)
_, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Context("reset by the peer", func() {
It("continues reading after receiving a remote error", func() {
frame := frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
str.AddStreamFrame(&frame)
str.RegisterRemoteError(testErr)
b := make([]byte, 4)
_, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
})
It("stops writing after receiving a remote error", func() {
var writeReturned bool
var n int
var err error
go func() {
n, err = str.Write([]byte("foobar"))
writeReturned = true
}()
str.RegisterRemoteError(testErr)
Eventually(func() bool { return writeReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
It("stops writing after receiving a remote error", func() {
var writeReturned bool
var n int
var err error
Context("reset locally", func() {
It("stops writing", func() {
var writeReturned bool
var n int
var err error
go func() {
n, err = str.Write([]byte("foobar"))
writeReturned = true
}()
str.RegisterRemoteError(testErr)
Eventually(func() bool { return writeReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
go func() {
n, err = str.Write([]byte("foobar"))
writeReturned = true
}()
Consistently(func() bool { return writeReturned }).Should(BeFalse())
str.Reset(testErr)
data := str.getDataForWriting(6)
Expect(data).To(BeNil())
Eventually(func() bool { return writeReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("doesn't allow further writes", func() {
str.Reset(testErr)
n, err := str.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("stops reading", func() {
var readReturned bool
var n int
var err error
go func() {
b := make([]byte, 4)
n, err = str.Read(b)
readReturned = true
}()
Consistently(func() bool { return readReturned }).Should(BeFalse())
str.Reset(testErr)
Eventually(func() bool { return readReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("doesn't allow further reads", func() {
str.AddStreamFrame(&frames.StreamFrame{
Data: []byte("foobar"),
})
str.Reset(testErr)
b := make([]byte, 6)
n, err := str.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
})

View File

@@ -13,4 +13,5 @@ type Stream interface {
io.Closer
StreamID() protocol.StreamID
CloseRemote(offset protocol.ByteCount)
Reset(error)
}