forked from quic-go/quic-go
@@ -17,6 +17,7 @@ type mockStream struct {
|
||||
}
|
||||
|
||||
func (mockStream) Close() error { return nil }
|
||||
func (mockStream) Reset(error) { panic("not implemented") }
|
||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
|
||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ func (s *mockStream) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (s *mockStream) Close() error { panic("not implemented") }
|
||||
func (s *mockStream) Reset(error) { panic("not implemented") }
|
||||
func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") }
|
||||
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
|
||||
|
||||
|
||||
62
stream.go
62
stream.go
@@ -25,12 +25,12 @@ type stream struct {
|
||||
readOffset protocol.ByteCount
|
||||
|
||||
// Once set, the errors must not be changed!
|
||||
readErr error
|
||||
writeErr error
|
||||
err error
|
||||
|
||||
cancelled utils.AtomicBool
|
||||
finishedReading utils.AtomicBool
|
||||
finishedWriting utils.AtomicBool
|
||||
resetLocally utils.AtomicBool
|
||||
|
||||
frameQueue *streamFrameSorter
|
||||
newFrameOrErrCond sync.Cond
|
||||
@@ -59,8 +59,8 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo
|
||||
|
||||
// Read implements io.Reader. It is not thread safe!
|
||||
func (s *stream) Read(p []byte) (int, error) {
|
||||
if s.cancelled.Get() {
|
||||
return 0, s.readErr
|
||||
if s.cancelled.Get() || s.resetLocally.Get() {
|
||||
return 0, s.err
|
||||
}
|
||||
if s.finishedReading.Get() {
|
||||
return 0, io.EOF
|
||||
@@ -73,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
|
||||
if frame == nil && bytesRead > 0 {
|
||||
s.mutex.Unlock()
|
||||
return bytesRead, s.readErr
|
||||
return bytesRead, s.err
|
||||
}
|
||||
|
||||
var err error
|
||||
for {
|
||||
// Stop waiting on errors
|
||||
if s.readErr != nil {
|
||||
err = s.readErr
|
||||
if s.err != nil {
|
||||
err = s.err
|
||||
break
|
||||
}
|
||||
if frame != nil {
|
||||
@@ -134,11 +134,15 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (s *stream) Write(p []byte) (int, error) {
|
||||
if s.resetLocally.Get() {
|
||||
return 0, s.err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.writeErr != nil {
|
||||
return 0, s.writeErr
|
||||
if s.err != nil {
|
||||
return 0, s.err
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
@@ -150,12 +154,12 @@ func (s *stream) Write(p []byte) (int, error) {
|
||||
|
||||
s.onData()
|
||||
|
||||
for s.dataForWriting != nil && s.writeErr == nil {
|
||||
for s.dataForWriting != nil && s.err == nil {
|
||||
s.doneWritingOrErrCond.Wait()
|
||||
}
|
||||
|
||||
if s.writeErr != nil {
|
||||
return 0, s.writeErr
|
||||
if s.err != nil {
|
||||
return 0, s.err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
@@ -164,7 +168,7 @@ func (s *stream) Write(p []byte) (int, error) {
|
||||
func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
||||
s.mutex.Lock()
|
||||
var l protocol.ByteCount
|
||||
if s.writeErr == nil {
|
||||
if s.err == nil {
|
||||
l = protocol.ByteCount(len(s.dataForWriting))
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
@@ -173,7 +177,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
||||
|
||||
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||
s.mutex.Lock()
|
||||
if s.writeErr != nil {
|
||||
if s.err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
@@ -204,7 +208,7 @@ func (s *stream) Close() error {
|
||||
|
||||
func (s *stream) shouldSendFin() bool {
|
||||
s.mutex.Lock()
|
||||
res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil
|
||||
res := s.finishedWriting.Get() && !s.finSent && s.err == nil && s.dataForWriting == nil
|
||||
s.mutex.Unlock()
|
||||
return res
|
||||
}
|
||||
@@ -247,12 +251,26 @@ func (s *stream) Cancel(err error) {
|
||||
|
||||
s.mutex.Lock()
|
||||
// errors must not be changed!
|
||||
if s.readErr == nil {
|
||||
s.readErr = err
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.newFrameOrErrCond.Signal()
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
}
|
||||
if s.writeErr == nil {
|
||||
s.writeErr = err
|
||||
// if s.writeErr == nil {
|
||||
// s.writeErr = err
|
||||
// }
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
// resets the stream locally
|
||||
func (s *stream) Reset(err error) {
|
||||
s.finishedReading.Set(true)
|
||||
s.resetLocally.Set(true)
|
||||
s.mutex.Lock()
|
||||
// errors must not be changed!
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.newFrameOrErrCond.Signal()
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
@@ -263,8 +281,8 @@ func (s *stream) RegisterRemoteError(err error) {
|
||||
s.finishedWriting.Set(true)
|
||||
s.mutex.Lock()
|
||||
// errors must not be changed!
|
||||
if s.writeErr == nil {
|
||||
s.writeErr = err
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
@@ -277,7 +295,7 @@ func (s *stream) finishedRead() bool {
|
||||
func (s *stream) finishedWriteAndSentFin() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent)
|
||||
return s.err != nil || (s.finishedWriting.Get() && s.finSent)
|
||||
}
|
||||
|
||||
func (s *stream) finished() bool {
|
||||
|
||||
103
stream_test.go
103
stream_test.go
@@ -403,33 +403,90 @@ var _ = Describe("Stream", func() {
|
||||
})
|
||||
|
||||
Context("resetting", func() {
|
||||
testErr := errors.New("received RST_STREAM")
|
||||
testErr := errors.New("testErr")
|
||||
|
||||
It("continues reading after receiving a remote error", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
}
|
||||
str.AddStreamFrame(&frame)
|
||||
str.RegisterRemoteError(testErr)
|
||||
b := make([]byte, 4)
|
||||
_, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Context("reset by the peer", func() {
|
||||
It("continues reading after receiving a remote error", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
}
|
||||
str.AddStreamFrame(&frame)
|
||||
str.RegisterRemoteError(testErr)
|
||||
b := make([]byte, 4)
|
||||
_, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("stops writing after receiving a remote error", func() {
|
||||
var writeReturned bool
|
||||
var n int
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
n, err = str.Write([]byte("foobar"))
|
||||
writeReturned = true
|
||||
}()
|
||||
str.RegisterRemoteError(testErr)
|
||||
Eventually(func() bool { return writeReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
It("stops writing after receiving a remote error", func() {
|
||||
var writeReturned bool
|
||||
var n int
|
||||
var err error
|
||||
Context("reset locally", func() {
|
||||
It("stops writing", func() {
|
||||
var writeReturned bool
|
||||
var n int
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
n, err = str.Write([]byte("foobar"))
|
||||
writeReturned = true
|
||||
}()
|
||||
str.RegisterRemoteError(testErr)
|
||||
Eventually(func() bool { return writeReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
go func() {
|
||||
n, err = str.Write([]byte("foobar"))
|
||||
writeReturned = true
|
||||
}()
|
||||
Consistently(func() bool { return writeReturned }).Should(BeFalse())
|
||||
str.Reset(testErr)
|
||||
data := str.getDataForWriting(6)
|
||||
Expect(data).To(BeNil())
|
||||
Eventually(func() bool { return writeReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("doesn't allow further writes", func() {
|
||||
str.Reset(testErr)
|
||||
n, err := str.Write([]byte("foobar"))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("stops reading", func() {
|
||||
var readReturned bool
|
||||
var n int
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
b := make([]byte, 4)
|
||||
n, err = str.Read(b)
|
||||
readReturned = true
|
||||
}()
|
||||
Consistently(func() bool { return readReturned }).Should(BeFalse())
|
||||
str.Reset(testErr)
|
||||
Eventually(func() bool { return readReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("doesn't allow further reads", func() {
|
||||
str.AddStreamFrame(&frames.StreamFrame{
|
||||
Data: []byte("foobar"),
|
||||
})
|
||||
str.Reset(testErr)
|
||||
b := make([]byte, 6)
|
||||
n, err := str.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -13,4 +13,5 @@ type Stream interface {
|
||||
io.Closer
|
||||
StreamID() protocol.StreamID
|
||||
CloseRemote(offset protocol.ByteCount)
|
||||
Reset(error)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user