allow stream.Read for streams that a RST was received for

and a lot of code improvements

fixes #385
This commit is contained in:
Marten Seemann
2017-01-05 19:21:16 +07:00
parent 72e9994c9c
commit 4e0ef58bab
7 changed files with 338 additions and 230 deletions

View File

@@ -402,8 +402,8 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
return errRstStreamOnInvalidStream
}
shouldSendRst := !str.finishedWriting()
s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode))
shouldSendRst := !str.finishedWriteAndSentFin()
str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode))
bytesSent, err := s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset)
if err != nil {
return err
@@ -485,15 +485,11 @@ func (s *Session) closeImpl(e error, remoteClose bool) error {
func (s *Session) closeStreamsWithError(err error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
s.closeStreamWithError(str, err)
str.Cancel(err)
return true, nil
})
}
func (s *Session) closeStreamWithError(str *stream, err error) {
str.RegisterError(err)
}
func (s *Session) sendPacket() error {
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
for {

View File

@@ -233,7 +233,7 @@ var _ = Describe("Session", func() {
Expect(err).To(MatchError("Error accessing the flowController map."))
})
It("closes streams with error", func() {
It("cancels streams with error", func() {
testErr := errors.New("test")
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
@@ -245,17 +245,18 @@ var _ = Describe("Session", func() {
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := str.Read(p)
Expect(err).ToNot(HaveOccurred())
session.closeStreamsWithError(testErr)
_, err = str.Read(p)
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streamsMap.openStreams).To(HaveLen(1))
Expect(session.streamsMap.openStreams).To(BeEmpty())
str, err = session.streamsMap.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
Expect(str).To(BeNil())
})
It("closes empty streams with error", func() {
It("cancels empty streams with error", func() {
testErr := errors.New("test")
session.GetOrOpenStream(5)
Expect(session.streamsMap.openStreams).To(HaveLen(2))
@@ -301,7 +302,7 @@ var _ = Describe("Session", func() {
})
Context("handling RST_STREAM frames", func() {
It("closes the receiving streams for writing and reading", func() {
It("closes the streams for writing", func() {
s, err := session.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
err = session.handleRstStreamFrame(&frames.RstStreamFrame{
@@ -312,9 +313,25 @@ var _ = Describe("Session", func() {
n, err := s.Write([]byte{0})
Expect(n).To(BeZero())
Expect(err).To(MatchError("RST_STREAM received with code 42"))
n, err = s.Read([]byte{0})
Expect(n).To(BeZero())
Expect(err).To(MatchError("RST_STREAM received with code 42"))
})
It("doesn't close the stream for reading", func() {
s, err := session.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
})
err = session.handleRstStreamFrame(&frames.RstStreamFrame{
StreamID: 5,
ErrorCode: 42,
ByteOffset: 6,
})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 3)
n, err := s.Read(b)
Expect(n).To(Equal(3))
Expect(err).ToNot(HaveOccurred())
})
It("queues a RST_STERAM frame with the correct offset", func() {

101
stream.go
View File

@@ -4,7 +4,6 @@ import (
"fmt"
"io"
"sync"
"sync/atomic"
"github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames"
@@ -16,6 +15,8 @@ import (
//
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
type stream struct {
mutex sync.Mutex
streamID protocol.StreamID
onData func()
@@ -23,14 +24,13 @@ type stream struct {
writeOffset protocol.ByteCount
readOffset protocol.ByteCount
// Once set, err must not be changed!
err error
mutex sync.Mutex
// Once set, the errors must not be changed!
readErr error
writeErr error
// eof is set if we are finished reading
eof int32 // really a bool
// closed is set when we are finished writing
closed int32 // really a bool
cancelled utils.AtomicBool
finishedReading utils.AtomicBool
finishedWriting utils.AtomicBool
frameQueue *streamFrameSorter
newFrameOrErrCond sync.Cond
@@ -59,7 +59,10 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo
// Read implements io.Reader. It is not thread safe!
func (s *stream) Read(p []byte) (int, error) {
if atomic.LoadInt32(&s.eof) != 0 {
if s.cancelled.Get() {
return 0, s.readErr
}
if s.finishedReading.Get() {
return 0, io.EOF
}
@@ -70,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) {
if frame == nil && bytesRead > 0 {
s.mutex.Unlock()
return bytesRead, s.err
return bytesRead, s.readErr
}
var err error
for {
// Stop waiting on errors
if s.err != nil {
err = s.err
if s.readErr != nil {
err = s.readErr
break
}
if frame != nil {
@@ -90,8 +93,10 @@ func (s *stream) Read(p []byte) (int, error) {
s.mutex.Unlock()
// Here, either frame != nil xor err != nil
// fmt.Printf("err: %#v, frame: %#v\n", err, frame)
if frame == nil {
atomic.StoreInt32(&s.eof, 1)
s.finishedReading.Set(true)
// We have an err and no data, return the error
return bytesRead, err
}
@@ -119,7 +124,7 @@ func (s *stream) Read(p []byte) (int, error) {
s.frameQueue.Pop()
s.mutex.Unlock()
if fin {
atomic.StoreInt32(&s.eof, 1)
s.finishedReading.Set(true)
return bytesRead, io.EOF
}
}
@@ -132,8 +137,8 @@ func (s *stream) Write(p []byte) (int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil {
return 0, s.err
if s.writeErr != nil {
return 0, s.writeErr
}
if len(p) == 0 {
@@ -145,12 +150,12 @@ func (s *stream) Write(p []byte) (int, error) {
s.onData()
for s.dataForWriting != nil && s.err == nil {
for s.dataForWriting != nil && s.writeErr == nil {
s.doneWritingOrErrCond.Wait()
}
if s.err != nil {
return 0, s.err
if s.writeErr != nil {
return 0, s.writeErr
}
return len(p), nil
@@ -159,7 +164,7 @@ func (s *stream) Write(p []byte) (int, error) {
func (s *stream) lenOfDataForWriting() protocol.ByteCount {
s.mutex.Lock()
var l protocol.ByteCount
if s.err == nil {
if s.writeErr == nil {
l = protocol.ByteCount(len(s.dataForWriting))
}
s.mutex.Unlock()
@@ -168,7 +173,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock()
if s.err != nil {
if s.writeErr != nil {
s.mutex.Unlock()
return nil
}
@@ -192,14 +197,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
// Close implements io.Closer
func (s *stream) Close() error {
atomic.StoreInt32(&s.closed, 1)
s.finishedWriting.Set(true)
s.onData()
return nil
}
func (s *stream) shouldSendFin() bool {
s.mutex.Lock()
res := atomic.LoadInt32(&s.closed) != 0 && !s.finSent && s.err == nil && s.dataForWriting == nil
res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil
s.mutex.Unlock()
return res
}
@@ -233,32 +238,50 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) {
s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset})
}
// RegisterError is called by session to indicate that an error occurred and the
// stream should be closed.
func (s *stream) RegisterError(err error) {
atomic.StoreInt32(&s.closed, 1)
// Cancel is called by session to indicate that an error occurred
// The stream should will be closed immediately
func (s *stream) Cancel(err error) {
s.finishedReading.Set(true)
s.finishedWriting.Set(true)
s.cancelled.Set(true)
s.mutex.Lock()
defer s.mutex.Unlock()
if s.err != nil { // s.err must not be changed!
return
}
s.err = err
s.doneWritingOrErrCond.Signal()
// errors must not be changed!
if s.readErr == nil {
s.readErr = err
s.newFrameOrErrCond.Signal()
}
func (s *stream) finishedReading() bool {
return atomic.LoadInt32(&s.eof) != 0
if s.writeErr == nil {
s.writeErr = err
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
}
func (s *stream) finishedWriting() bool {
// resets the stream remotely
func (s *stream) RegisterRemoteError(err error) {
s.finishedWriting.Set(true)
s.mutex.Lock()
// errors must not be changed!
if s.writeErr == nil {
s.writeErr = err
s.doneWritingOrErrCond.Signal()
}
s.mutex.Unlock()
}
func (s *stream) finishedRead() bool {
return s.finishedReading.Get()
}
func (s *stream) finishedWriteAndSentFin() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.err != nil || (atomic.LoadInt32(&s.closed) != 0 && s.finSent)
return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent)
}
func (s *stream) finished() bool {
return s.finishedReading() && s.finishedWriting()
return s.finishedRead() && s.finishedWriteAndSentFin()
}
func (s *stream) StreamID() protocol.StreamID {

View File

@@ -239,7 +239,7 @@ var _ = Describe("Stream Framer", func() {
Context("sending FINs", func() {
It("sends FINs when streams are closed", func() {
stream1.writeOffset = 42
stream1.closed = 1
stream1.finishedWriting.Set(true)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.streamID))
@@ -250,7 +250,7 @@ var _ = Describe("Stream Framer", func() {
It("sends FINs when flow-control blocked", func() {
stream1.writeOffset = 42
stream1.closed = 1
stream1.finishedWriting.Set(true)
fcm.sendWindowSizes[stream1.StreamID()] = 42
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
@@ -262,7 +262,7 @@ var _ = Describe("Stream Framer", func() {
It("bundles FINs with data", func() {
stream1.dataForWriting = []byte("foobar")
stream1.closed = 1
stream1.finishedWriting.Set(true)
fs := framer.PopStreamFrames(1000)
Expect(fs).To(HaveLen(1))
Expect(fs[0].StreamID).To(Equal(stream1.streamID))
@@ -401,7 +401,7 @@ var _ = Describe("Stream Framer", func() {
frames := framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
Expect(frames[0].FinBit).To(BeFalse())
stream1.closed = 1
stream1.finishedWriting.Set(true)
frames = framer.PopStreamFrames(1000)
Expect(frames).To(HaveLen(1))
Expect(frames[0].FinBit).To(BeTrue())

View File

@@ -288,156 +288,9 @@ var _ = Describe("Stream", func() {
Expect(err).ToNot(HaveOccurred())
Expect(onDataCalled).To(BeTrue())
})
})
Context("writing", func() {
It("writes and gets all data at once", func(done Done) {
go func() {
n, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
Eventually(func() []byte {
str.mutex.Lock()
defer str.mutex.Unlock()
return str.dataForWriting
}).Should(Equal([]byte("foobar")))
Expect(onDataCalled).To(BeTrue())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
data := str.getDataForWriting(1000)
Expect(data).To(Equal([]byte("foobar")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
})
It("writes and gets data in two turns", func(done Done) {
go func() {
n, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
Eventually(func() []byte {
str.mutex.Lock()
defer str.mutex.Unlock()
return str.dataForWriting
}).Should(Equal([]byte("foobar")))
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
data := str.getDataForWriting(3)
Expect(data).To(Equal([]byte("foo")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(3)))
Expect(str.dataForWriting).ToNot(BeNil())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(3)))
data = str.getDataForWriting(3)
Expect(data).To(Equal([]byte("bar")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0)))
})
It("returns remote errors", func(done Done) {
testErr := errors.New("test")
str.RegisterError(testErr)
n, err := str.Write([]byte("foo"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
close(done)
})
It("getDataForWriting returns nil if no data is available", func() {
Expect(str.getDataForWriting(1000)).To(BeNil())
})
It("copies the slice while writing", func() {
s := []byte("foo")
go func() {
n, err := str.Write(s)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
}()
Eventually(func() protocol.ByteCount { return str.lenOfDataForWriting() }).ShouldNot(BeZero())
s[0] = 'v'
Expect(str.getDataForWriting(3)).To(Equal([]byte("foo")))
})
It("returns when given a nil input", func() {
n, err := str.Write(nil)
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("returns when given an empty slice", func() {
n, err := str.Write([]byte(""))
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
})
Context("closing", func() {
It("sets closed when calling Close", func() {
str.Close()
Expect(str.closed).ToNot(BeZero())
})
It("allows FIN", func() {
str.Close()
Expect(str.shouldSendFin()).To(BeTrue())
})
It("does not allow FIN when there's still data", func() {
str.dataForWriting = []byte("foobar")
str.Close()
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN when the stream is not closed", func() {
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN after an error", func() {
str.RegisterError(errors.New("test"))
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN twice", func() {
str.Close()
Expect(str.shouldSendFin()).To(BeTrue())
str.sentFin()
Expect(str.shouldSendFin()).To(BeFalse())
})
})
Context("flow control, for receiving", func() {
BeforeEach(func() {
str.flowControlManager = &mockFlowControlHandler{}
})
It("updates the highestReceived value in the flow controller", func() {
frame := frames.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.AddStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID))
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6)))
})
It("errors when a StreamFrames causes a flow control violation", func() {
testErr := errors.New("flow control violation")
str.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr
frame := frames.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.AddStreamFrame(&frame)
Expect(err).To(MatchError(testErr))
})
})
Context("closing", func() {
Context("with fin bit", func() {
Context("with FIN bit", func() {
It("returns EOFs", func() {
frame := frames.StreamFrame{
Offset: 0,
@@ -509,42 +362,191 @@ var _ = Describe("Stream", func() {
})
})
Context("with remote errors", func() {
testErr := errors.New("test error")
It("returns EOF if data is read before", func() {
frame := frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
FinBit: true,
}
err := str.AddStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
str.RegisterError(testErr)
b := make([]byte, 4)
Context("when CloseRemote is called", func() {
It("closes", func() {
str.CloseRemote(0)
b := make([]byte, 8)
n, err := str.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = str.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
})
})
It("returns errors", func() {
Context("cancelling the stream", func() {
testErr := errors.New("test error")
It("immediately returns all reads", func() {
var readReturned bool
var n int
var err error
b := make([]byte, 4)
go func() {
n, err = str.Read(b)
readReturned = true
}()
Consistently(func() bool { return readReturned }).Should(BeFalse())
str.Cancel(testErr)
Eventually(func() bool { return readReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
It("errors for all following reads", func() {
str.Cancel(testErr)
b := make([]byte, 1)
n, err := str.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
})
Context("resetting", func() {
testErr := errors.New("received RST_STREAM")
It("continues reading after receiving a remote error", func() {
frame := frames.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
}
err := str.AddStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
str.RegisterError(testErr)
str.AddStreamFrame(&frame)
str.RegisterRemoteError(testErr)
b := make([]byte, 4)
n, err := str.Read(b)
_, err := str.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF}))
n, err = str.Read(b)
})
It("stops writing after receiving a remote error", func() {
var writeReturned bool
var n int
var err error
go func() {
n, err = str.Write([]byte("foobar"))
writeReturned = true
}()
str.RegisterRemoteError(testErr)
Eventually(func() bool { return writeReturned }).Should(BeTrue())
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
})
Context("writing", func() {
It("writes and gets all data at once", func(done Done) {
go func() {
n, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
Eventually(func() []byte {
str.mutex.Lock()
defer str.mutex.Unlock()
return str.dataForWriting
}).Should(Equal([]byte("foobar")))
Expect(onDataCalled).To(BeTrue())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
data := str.getDataForWriting(1000)
Expect(data).To(Equal([]byte("foobar")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
})
It("writes and gets data in two turns", func(done Done) {
go func() {
n, err := str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
close(done)
}()
Eventually(func() []byte {
str.mutex.Lock()
defer str.mutex.Unlock()
return str.dataForWriting
}).Should(Equal([]byte("foobar")))
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(6)))
data := str.getDataForWriting(3)
Expect(data).To(Equal([]byte("foo")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(3)))
Expect(str.dataForWriting).ToNot(BeNil())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(3)))
data = str.getDataForWriting(3)
Expect(data).To(Equal([]byte("bar")))
Expect(str.writeOffset).To(Equal(protocol.ByteCount(6)))
Expect(str.dataForWriting).To(BeNil())
Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0)))
})
It("getDataForWriting returns nil if no data is available", func() {
Expect(str.getDataForWriting(1000)).To(BeNil())
})
It("copies the slice while writing", func() {
s := []byte("foo")
go func() {
n, err := str.Write(s)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
}()
Eventually(func() protocol.ByteCount { return str.lenOfDataForWriting() }).ShouldNot(BeZero())
s[0] = 'v'
Expect(str.getDataForWriting(3)).To(Equal([]byte("foo")))
})
It("returns when given a nil input", func() {
n, err := str.Write(nil)
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
It("returns when given an empty slice", func() {
n, err := str.Write([]byte(""))
Expect(n).To(BeZero())
Expect(err).ToNot(HaveOccurred())
})
Context("closing", func() {
It("sets finishedWriting when calling Close", func() {
str.Close()
Expect(str.finishedWriting.Get()).To(BeTrue())
})
It("allows FIN", func() {
str.Close()
Expect(str.shouldSendFin()).To(BeTrue())
})
It("does not allow FIN when there's still data", func() {
str.dataForWriting = []byte("foobar")
str.Close()
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN when the stream is not closed", func() {
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN after an error", func() {
str.Cancel(errors.New("test"))
Expect(str.shouldSendFin()).To(BeFalse())
})
It("does not allow FIN twice", func() {
str.Close()
Expect(str.shouldSendFin()).To(BeTrue())
str.sentFin()
Expect(str.shouldSendFin()).To(BeFalse())
})
})
Context("cancelling", func() {
testErr := errors.New("test")
It("returns errors when the stream is cancelled", func() {
str.Cancel(testErr)
n, err := str.Write([]byte("foo"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(testErr))
})
@@ -556,21 +558,40 @@ var _ = Describe("Stream", func() {
}()
Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil())
Expect(str.lenOfDataForWriting()).ToNot(BeZero())
str.RegisterError(testErr)
str.Cancel(testErr)
data := str.getDataForWriting(6)
Expect(data).To(BeNil())
Expect(str.lenOfDataForWriting()).To(BeZero())
})
})
})
Context("when CloseRemote is called", func() {
It("closes", func() {
str.CloseRemote(0)
b := make([]byte, 8)
n, err := str.Read(b)
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})
Context("flow control, for receiving", func() {
BeforeEach(func() {
str.flowControlManager = &mockFlowControlHandler{}
})
It("updates the highestReceived value in the flow controller", func() {
frame := frames.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.AddStreamFrame(&frame)
Expect(err).ToNot(HaveOccurred())
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(str.streamID))
Expect(str.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(2 + 6)))
})
It("errors when a StreamFrames causes a flow control violation", func() {
testErr := errors.New("flow control violation")
str.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr
frame := frames.StreamFrame{
Offset: 2,
Data: []byte("foobar"),
}
err := str.AddStreamFrame(&frame)
Expect(err).To(MatchError(testErr))
})
})
})

22
utils/atomic_bool.go Normal file
View File

@@ -0,0 +1,22 @@
package utils
import "sync/atomic"
// An AtomicBool is an atomic bool
type AtomicBool struct {
v int32
}
// Set sets the value
func (a *AtomicBool) Set(value bool) {
var n int32
if value {
n = 1
}
atomic.StoreInt32(&a.v, n)
}
// Get gets the value
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.v) != 0
}

29
utils/atomic_bool_test.go Normal file
View File

@@ -0,0 +1,29 @@
package utils
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Atomic Bool", func() {
var a *AtomicBool
BeforeEach(func() {
a = &AtomicBool{}
})
It("has the right default value", func() {
Expect(a.Get()).To(BeFalse())
})
It("sets the value to true", func() {
a.Set(true)
Expect(a.Get()).To(BeTrue())
})
It("sets the value to false", func() {
a.Set(true)
a.Set(false)
Expect(a.Get()).To(BeFalse())
})
})