forked from quic-go/quic-go
allow stream.Read for streams that a RST was received for
and a lot of code improvements fixes #385
This commit is contained in:
10
session.go
10
session.go
@@ -402,8 +402,8 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
|
||||
return errRstStreamOnInvalidStream
|
||||
}
|
||||
|
||||
shouldSendRst := !str.finishedWriting()
|
||||
s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode))
|
||||
shouldSendRst := !str.finishedWriteAndSentFin()
|
||||
str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode))
|
||||
bytesSent, err := s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -485,15 +485,11 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
|
||||
|
||||
func (s *Session) closeStreamsWithError(err error) {
|
||||
s.streamsMap.Iterate(func(str *stream) (bool, error) {
|
||||
s.closeStreamWithError(str, err)
|
||||
str.Cancel(err)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) closeStreamWithError(str *stream, err error) {
|
||||
str.RegisterError(err)
|
||||
}
|
||||
|
||||
func (s *Session) sendPacket() error {
|
||||
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
|
||||
for {
|
||||
|
||||
@@ -233,7 +233,7 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).To(MatchError("Error accessing the flowController map."))
|
||||
})
|
||||
|
||||
It("closes streams with error", func() {
|
||||
It("cancels streams with error", func() {
|
||||
testErr := errors.New("test")
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
@@ -245,17 +245,18 @@ var _ = Describe("Session", func() {
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := str.Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
session.closeStreamsWithError(testErr)
|
||||
_, err = str.Read(p)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(1))
|
||||
Expect(session.streamsMap.openStreams).To(BeEmpty())
|
||||
str, err = session.streamsMap.GetOrOpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
})
|
||||
|
||||
It("closes empty streams with error", func() {
|
||||
It("cancels empty streams with error", func() {
|
||||
testErr := errors.New("test")
|
||||
session.GetOrOpenStream(5)
|
||||
Expect(session.streamsMap.openStreams).To(HaveLen(2))
|
||||
@@ -301,7 +302,7 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
Context("handling RST_STREAM frames", func() {
|
||||
It("closes the receiving streams for writing and reading", func() {
|
||||
It("closes the streams for writing", func() {
|
||||
s, err := session.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = session.handleRstStreamFrame(&frames.RstStreamFrame{
|
||||
@@ -312,9 +313,25 @@ var _ = Describe("Session", func() {
|
||||
n, err := s.Write([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError("RST_STREAM received with code 42"))
|
||||
n, err = s.Read([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError("RST_STREAM received with code 42"))
|
||||
})
|
||||
|
||||
It("doesn't close the stream for reading", func() {
|
||||
s, err := session.GetOrOpenStream(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte("foobar"),
|
||||
})
|
||||
err = session.handleRstStreamFrame(&frames.RstStreamFrame{
|
||||
StreamID: 5,
|
||||
ErrorCode: 42,
|
||||
ByteOffset: 6,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 3)
|
||||
n, err := s.Read(b)
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("queues a RST_STERAM frame with the correct offset", func() {
|
||||
|
||||
99
stream.go
99
stream.go
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
@@ -16,6 +15,8 @@ import (
|
||||
//
|
||||
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
|
||||
type stream struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamID protocol.StreamID
|
||||
onData func()
|
||||
|
||||
@@ -23,14 +24,13 @@ type stream struct {
|
||||
writeOffset protocol.ByteCount
|
||||
readOffset protocol.ByteCount
|
||||
|
||||
// Once set, err must not be changed!
|
||||
err error
|
||||
mutex sync.Mutex
|
||||
// Once set, the errors must not be changed!
|
||||
readErr error
|
||||
writeErr error
|
||||
|
||||
// eof is set if we are finished reading
|
||||
eof int32 // really a bool
|
||||
// closed is set when we are finished writing
|
||||
closed int32 // really a bool
|
||||
cancelled utils.AtomicBool
|
||||
finishedReading utils.AtomicBool
|
||||
finishedWriting utils.AtomicBool
|
||||
|
||||
frameQueue *streamFrameSorter
|
||||
newFrameOrErrCond sync.Cond
|
||||
@@ -59,7 +59,10 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo
|
||||
|
||||
// Read implements io.Reader. It is not thread safe!
|
||||
func (s *stream) Read(p []byte) (int, error) {
|
||||
if atomic.LoadInt32(&s.eof) != 0 {
|
||||
if s.cancelled.Get() {
|
||||
return 0, s.readErr
|
||||
}
|
||||
if s.finishedReading.Get() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
@@ -70,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
|
||||
if frame == nil && bytesRead > 0 {
|
||||
s.mutex.Unlock()
|
||||
return bytesRead, s.err
|
||||
return bytesRead, s.readErr
|
||||
}
|
||||
|
||||
var err error
|
||||
for {
|
||||
// Stop waiting on errors
|
||||
if s.err != nil {
|
||||
err = s.err
|
||||
if s.readErr != nil {
|
||||
err = s.readErr
|
||||
break
|
||||
}
|
||||
if frame != nil {
|
||||
@@ -90,8 +93,10 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
s.mutex.Unlock()
|
||||
// Here, either frame != nil xor err != nil
|
||||
|
||||
// fmt.Printf("err: %#v, frame: %#v\n", err, frame)
|
||||
|
||||
if frame == nil {
|
||||
atomic.StoreInt32(&s.eof, 1)
|
||||
s.finishedReading.Set(true)
|
||||
// We have an err and no data, return the error
|
||||
return bytesRead, err
|
||||
}
|
||||
@@ -119,7 +124,7 @@ func (s *stream) Read(p []byte) (int, error) {
|
||||
s.frameQueue.Pop()
|
||||
s.mutex.Unlock()
|
||||
if fin {
|
||||
atomic.StoreInt32(&s.eof, 1)
|
||||
s.finishedReading.Set(true)
|
||||
return bytesRead, io.EOF
|
||||
}
|
||||
}
|
||||
@@ -132,8 +137,8 @@ func (s *stream) Write(p []byte) (int, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.err != nil {
|
||||
return 0, s.err
|
||||
if s.writeErr != nil {
|
||||
return 0, s.writeErr
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
@@ -145,12 +150,12 @@ func (s *stream) Write(p []byte) (int, error) {
|
||||
|
||||
s.onData()
|
||||
|
||||
for s.dataForWriting != nil && s.err == nil {
|
||||
for s.dataForWriting != nil && s.writeErr == nil {
|
||||
s.doneWritingOrErrCond.Wait()
|
||||
}
|
||||
|
||||
if s.err != nil {
|
||||
return 0, s.err
|
||||
if s.writeErr != nil {
|
||||
return 0, s.writeErr
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
@@ -159,7 +164,7 @@ func (s *stream) Write(p []byte) (int, error) {
|
||||
func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
||||
s.mutex.Lock()
|
||||
var l protocol.ByteCount
|
||||
if s.err == nil {
|
||||
if s.writeErr == nil {
|
||||
l = protocol.ByteCount(len(s.dataForWriting))
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
@@ -168,7 +173,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
|
||||
|
||||
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||
s.mutex.Lock()
|
||||
if s.err != nil {
|
||||
if s.writeErr != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
@@ -192,14 +197,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
|
||||
|
||||
// Close implements io.Closer
|
||||
func (s *stream) Close() error {
|
||||
atomic.StoreInt32(&s.closed, 1)
|
||||
s.finishedWriting.Set(true)
|
||||
s.onData()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) shouldSendFin() bool {
|
||||
s.mutex.Lock()
|
||||
res := atomic.LoadInt32(&s.closed) != 0 && !s.finSent && s.err == nil && s.dataForWriting == nil
|
||||
res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil
|
||||
s.mutex.Unlock()
|
||||
return res
|
||||
}
|
||||
@@ -233,32 +238,50 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) {
|
||||
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
|
||||
}
|
||||
|
||||
// RegisterError is called by session to indicate that an error occurred and the
|
||||
// stream should be closed.
|
||||
func (s *stream) RegisterError(err error) {
|
||||
atomic.StoreInt32(&s.closed, 1)
|
||||
// Cancel is called by session to indicate that an error occurred
|
||||
// The stream should will be closed immediately
|
||||
func (s *stream) Cancel(err error) {
|
||||
s.finishedReading.Set(true)
|
||||
s.finishedWriting.Set(true)
|
||||
s.cancelled.Set(true)
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.err != nil { // s.err must not be changed!
|
||||
return
|
||||
}
|
||||
s.err = err
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
// errors must not be changed!
|
||||
if s.readErr == nil {
|
||||
s.readErr = err
|
||||
s.newFrameOrErrCond.Signal()
|
||||
}
|
||||
if s.writeErr == nil {
|
||||
s.writeErr = err
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (s *stream) finishedReading() bool {
|
||||
return atomic.LoadInt32(&s.eof) != 0
|
||||
// resets the stream remotely
|
||||
func (s *stream) RegisterRemoteError(err error) {
|
||||
s.finishedWriting.Set(true)
|
||||
s.mutex.Lock()
|
||||
// errors must not be changed!
|
||||
if s.writeErr == nil {
|
||||
s.writeErr = err
|
||||
s.doneWritingOrErrCond.Signal()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (s *stream) finishedWriting() bool {
|
||||
func (s *stream) finishedRead() bool {
|
||||
return s.finishedReading.Get()
|
||||
}
|
||||
|
||||
func (s *stream) finishedWriteAndSentFin() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
return s.err != nil || (atomic.LoadInt32(&s.closed) != 0 && s.finSent)
|
||||
return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent)
|
||||
}
|
||||
|
||||
func (s *stream) finished() bool {
|
||||
return s.finishedReading() && s.finishedWriting()
|
||||
return s.finishedRead() && s.finishedWriteAndSentFin()
|
||||
}
|
||||
|
||||
func (s *stream) StreamID() protocol.StreamID {
|
||||
|
||||
@@ -239,7 +239,7 @@ var _ = Describe("Stream Framer", func() {
|
||||
Context("sending FINs", func() {
|
||||
It("sends FINs when streams are closed", func() {
|
||||
stream1.writeOffset = 42
|
||||
stream1.closed = 1
|
||||
stream1.finishedWriting.Set(true)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.streamID))
|
||||
@@ -250,7 +250,7 @@ var _ = Describe("Stream Framer", func() {
|
||||
|
||||
It("sends FINs when flow-control blocked", func() {
|
||||
stream1.writeOffset = 42
|
||||
stream1.closed = 1
|
||||
stream1.finishedWriting.Set(true)
|
||||
fcm.sendWindowSizes[stream1.StreamID()] = 42
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
@@ -262,7 +262,7 @@ var _ = Describe("Stream Framer", func() {
|
||||
|
||||
It("bundles FINs with data", func() {
|
||||
stream1.dataForWriting = []byte("foobar")
|
||||
stream1.closed = 1
|
||||
stream1.finishedWriting.Set(true)
|
||||
fs := framer.PopStreamFrames(1000)
|
||||
Expect(fs).To(HaveLen(1))
|
||||
Expect(fs[0].StreamID).To(Equal(stream1.streamID))
|
||||
@@ -401,7 +401,7 @@ var _ = Describe("Stream Framer", func() {
|
||||
frames := framer.PopStreamFrames(1000)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].FinBit).To(BeFalse())
|
||||
stream1.closed = 1
|
||||
stream1.finishedWriting.Set(true)
|
||||
frames = framer.PopStreamFrames(1000)
|
||||
Expect(frames).To(HaveLen(1))
|
||||
Expect(frames[0].FinBit).To(BeTrue())
|
||||
|
||||
385
stream_test.go
385
stream_test.go
@@ -288,156 +288,9 @@ var _ = Describe("Stream", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(onDataCalled).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes and gets all data at once", func(done Done) {
|
||||
go func() {
|
||||
n, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte {
|
||||
str.mutex.Lock()
|
||||
defer str.mutex.Unlock()
|
||||
return str.dataForWriting
|
||||
}).Should(Equal([]byte("foobar")))
|
||||
Expect(onDataCalled).To(BeTrue())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
|
||||
data := str.getDataForWriting(1000)
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(str.dataForWriting).To(BeNil())
|
||||
})
|
||||
|
||||
It("writes and gets data in two turns", func(done Done) {
|
||||
go func() {
|
||||
n, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte {
|
||||
str.mutex.Lock()
|
||||
defer str.mutex.Unlock()
|
||||
return str.dataForWriting
|
||||
}).Should(Equal([]byte("foobar")))
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
|
||||
data := str.getDataForWriting(3)
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(3)))
|
||||
Expect(str.dataForWriting).ToNot(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(3)))
|
||||
data = str.getDataForWriting(3)
|
||||
Expect(data).To(Equal([]byte("bar")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(str.dataForWriting).To(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0)))
|
||||
})
|
||||
|
||||
It("returns remote errors", func(done Done) {
|
||||
testErr := errors.New("test")
|
||||
str.RegisterError(testErr)
|
||||
n, err := str.Write([]byte("foo"))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("getDataForWriting returns nil if no data is available", func() {
|
||||
Expect(str.getDataForWriting(1000)).To(BeNil())
|
||||
})
|
||||
|
||||
It("copies the slice while writing", func() {
|
||||
s := []byte("foo")
|
||||
go func() {
|
||||
n, err := str.Write(s)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
}()
|
||||
Eventually(func() protocol.ByteCount { return str.lenOfDataForWriting() }).ShouldNot(BeZero())
|
||||
s[0] = 'v'
|
||||
Expect(str.getDataForWriting(3)).To(Equal([]byte("foo")))
|
||||
})
|
||||
|
||||
It("returns when given a nil input", func() {
|
||||
n, err := str.Write(nil)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns when given an empty slice", func() {
|
||||
n, err := str.Write([]byte(""))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("sets closed when calling Close", func() {
|
||||
str.Close()
|
||||
Expect(str.closed).ToNot(BeZero())
|
||||
})
|
||||
|
||||
It("allows FIN", func() {
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not allow FIN when there's still data", func() {
|
||||
str.dataForWriting = []byte("foobar")
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN when the stream is not closed", func() {
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN after an error", func() {
|
||||
str.RegisterError(errors.New("test"))
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN twice", func() {
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeTrue())
|
||||
str.sentFin()
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("flow control, for receiving", func() {
|
||||
BeforeEach(func() {
|
||||
str.flowControlManager = &mockFlowControlHandler{}
|
||||
})
|
||||
|
||||
It("updates the highestReceived value in the flow controller", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID))
|
||||
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6)))
|
||||
})
|
||||
|
||||
It("errors when a StreamFrames causes a flow control violation", func() {
|
||||
testErr := errors.New("flow control violation")
|
||||
str.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
Context("with fin bit", func() {
|
||||
Context("with FIN bit", func() {
|
||||
It("returns EOFs", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 0,
|
||||
@@ -509,42 +362,191 @@ var _ = Describe("Stream", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("with remote errors", func() {
|
||||
testErr := errors.New("test error")
|
||||
|
||||
It("returns EOF if data is read before", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
FinBit: true,
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.RegisterError(testErr)
|
||||
b := make([]byte, 4)
|
||||
Context("when CloseRemote is called", func() {
|
||||
It("closes", func() {
|
||||
str.CloseRemote(0)
|
||||
b := make([]byte, 8)
|
||||
n, err := str.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
n, err = str.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
Context("cancelling the stream", func() {
|
||||
testErr := errors.New("test error")
|
||||
|
||||
It("immediately returns all reads", func() {
|
||||
var readReturned bool
|
||||
var n int
|
||||
var err error
|
||||
b := make([]byte, 4)
|
||||
go func() {
|
||||
n, err = str.Read(b)
|
||||
readReturned = true
|
||||
}()
|
||||
Consistently(func() bool { return readReturned }).Should(BeFalse())
|
||||
str.Cancel(testErr)
|
||||
Eventually(func() bool { return readReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("errors for all following reads", func() {
|
||||
str.Cancel(testErr)
|
||||
b := make([]byte, 1)
|
||||
n, err := str.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("resetting", func() {
|
||||
testErr := errors.New("received RST_STREAM")
|
||||
|
||||
It("continues reading after receiving a remote error", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.RegisterError(testErr)
|
||||
str.AddStreamFrame(&frame)
|
||||
str.RegisterRemoteError(testErr)
|
||||
b := make([]byte, 4)
|
||||
n, err := str.Read(b)
|
||||
_, err := str.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
|
||||
n, err = str.Read(b)
|
||||
})
|
||||
|
||||
It("stops writing after receiving a remote error", func() {
|
||||
var writeReturned bool
|
||||
var n int
|
||||
var err error
|
||||
|
||||
go func() {
|
||||
n, err = str.Write([]byte("foobar"))
|
||||
writeReturned = true
|
||||
}()
|
||||
str.RegisterRemoteError(testErr)
|
||||
Eventually(func() bool { return writeReturned }).Should(BeTrue())
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("writes and gets all data at once", func(done Done) {
|
||||
go func() {
|
||||
n, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte {
|
||||
str.mutex.Lock()
|
||||
defer str.mutex.Unlock()
|
||||
return str.dataForWriting
|
||||
}).Should(Equal([]byte("foobar")))
|
||||
Expect(onDataCalled).To(BeTrue())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
|
||||
data := str.getDataForWriting(1000)
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(str.dataForWriting).To(BeNil())
|
||||
})
|
||||
|
||||
It("writes and gets data in two turns", func(done Done) {
|
||||
go func() {
|
||||
n, err := str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte {
|
||||
str.mutex.Lock()
|
||||
defer str.mutex.Unlock()
|
||||
return str.dataForWriting
|
||||
}).Should(Equal([]byte("foobar")))
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
|
||||
data := str.getDataForWriting(3)
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(3)))
|
||||
Expect(str.dataForWriting).ToNot(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(3)))
|
||||
data = str.getDataForWriting(3)
|
||||
Expect(data).To(Equal([]byte("bar")))
|
||||
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
|
||||
Expect(str.dataForWriting).To(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0)))
|
||||
})
|
||||
|
||||
It("getDataForWriting returns nil if no data is available", func() {
|
||||
Expect(str.getDataForWriting(1000)).To(BeNil())
|
||||
})
|
||||
|
||||
It("copies the slice while writing", func() {
|
||||
s := []byte("foo")
|
||||
go func() {
|
||||
n, err := str.Write(s)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
}()
|
||||
Eventually(func() protocol.ByteCount { return str.lenOfDataForWriting() }).ShouldNot(BeZero())
|
||||
s[0] = 'v'
|
||||
Expect(str.getDataForWriting(3)).To(Equal([]byte("foo")))
|
||||
})
|
||||
|
||||
It("returns when given a nil input", func() {
|
||||
n, err := str.Write(nil)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns when given an empty slice", func() {
|
||||
n, err := str.Write([]byte(""))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("sets finishedWriting when calling Close", func() {
|
||||
str.Close()
|
||||
Expect(str.finishedWriting.Get()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("allows FIN", func() {
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not allow FIN when there's still data", func() {
|
||||
str.dataForWriting = []byte("foobar")
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN when the stream is not closed", func() {
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN after an error", func() {
|
||||
str.Cancel(errors.New("test"))
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not allow FIN twice", func() {
|
||||
str.Close()
|
||||
Expect(str.shouldSendFin()).To(BeTrue())
|
||||
str.sentFin()
|
||||
Expect(str.shouldSendFin()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("cancelling", func() {
|
||||
testErr := errors.New("test")
|
||||
|
||||
It("returns errors when the stream is cancelled", func() {
|
||||
str.Cancel(testErr)
|
||||
n, err := str.Write([]byte("foo"))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
@@ -556,21 +558,40 @@ var _ = Describe("Stream", func() {
|
||||
}()
|
||||
Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).ToNot(BeZero())
|
||||
str.RegisterError(testErr)
|
||||
str.Cancel(testErr)
|
||||
data := str.getDataForWriting(6)
|
||||
Expect(data).To(BeNil())
|
||||
Expect(str.lenOfDataForWriting()).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("when CloseRemote is called", func() {
|
||||
It("closes", func() {
|
||||
str.CloseRemote(0)
|
||||
b := make([]byte, 8)
|
||||
n, err := str.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
Context("flow control, for receiving", func() {
|
||||
BeforeEach(func() {
|
||||
str.flowControlManager = &mockFlowControlHandler{}
|
||||
})
|
||||
|
||||
It("updates the highestReceived value in the flow controller", func() {
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID))
|
||||
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6)))
|
||||
})
|
||||
|
||||
It("errors when a StreamFrames causes a flow control violation", func() {
|
||||
testErr := errors.New("flow control violation")
|
||||
str.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr
|
||||
frame := frames.StreamFrame{
|
||||
Offset: 2,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := str.AddStreamFrame(&frame)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
22
utils/atomic_bool.go
Normal file
22
utils/atomic_bool.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// An AtomicBool is an atomic bool
|
||||
type AtomicBool struct {
|
||||
v int32
|
||||
}
|
||||
|
||||
// Set sets the value
|
||||
func (a *AtomicBool) Set(value bool) {
|
||||
var n int32
|
||||
if value {
|
||||
n = 1
|
||||
}
|
||||
atomic.StoreInt32(&a.v, n)
|
||||
}
|
||||
|
||||
// Get gets the value
|
||||
func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.v) != 0
|
||||
}
|
||||
29
utils/atomic_bool_test.go
Normal file
29
utils/atomic_bool_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Atomic Bool", func() {
|
||||
var a *AtomicBool
|
||||
|
||||
BeforeEach(func() {
|
||||
a = &AtomicBool{}
|
||||
})
|
||||
|
||||
It("has the right default value", func() {
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("sets the value to true", func() {
|
||||
a.Set(true)
|
||||
Expect(a.Get()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the value to false", func() {
|
||||
a.Set(true)
|
||||
a.Set(false)
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user