From 4e0ef58babf3f4bf39c2e6b3861de5af1843fda1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 5 Jan 2017 19:21:16 +0700 Subject: [PATCH] allow stream.Read for streams that a RST was received for and a lot of code improvements fixes #385 --- session.go | 10 +- session_test.go | 31 +++- stream.go | 99 ++++++---- stream_framer_test.go | 8 +- stream_test.go | 369 ++++++++++++++++++++------------------ utils/atomic_bool.go | 22 +++ utils/atomic_bool_test.go | 29 +++ 7 files changed, 338 insertions(+), 230 deletions(-) create mode 100644 utils/atomic_bool.go create mode 100644 utils/atomic_bool_test.go diff --git a/session.go b/session.go index 96c13c97..a846ea11 100644 --- a/session.go +++ b/session.go @@ -402,8 +402,8 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { return errRstStreamOnInvalidStream } - shouldSendRst := !str.finishedWriting() - s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) + shouldSendRst := !str.finishedWriteAndSentFin() + str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) bytesSent, err := s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) if err != nil { return err @@ -485,15 +485,11 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { func (s *Session) closeStreamsWithError(err error) { s.streamsMap.Iterate(func(str *stream) (bool, error) { - s.closeStreamWithError(str, err) + str.Cancel(err) return true, nil }) } -func (s *Session) closeStreamWithError(str *stream, err error) { - str.RegisterError(err) -} - func (s *Session) sendPacket() error { // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { diff --git a/session_test.go b/session_test.go index 186f0f62..0cd123c9 100644 --- a/session_test.go +++ b/session_test.go @@ -233,7 +233,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("Error accessing the flowController map.")) }) - It("closes streams with error", func() { + It("cancels streams with error", func() { testErr := errors.New("test") session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, @@ -245,17 +245,18 @@ var _ = Describe("Session", func() { Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := str.Read(p) + Expect(err).ToNot(HaveOccurred()) session.closeStreamsWithError(testErr) _, err = str.Read(p) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(1)) + Expect(session.streamsMap.openStreams).To(BeEmpty()) str, err = session.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) }) - It("closes empty streams with error", func() { + It("cancels empty streams with error", func() { testErr := errors.New("test") session.GetOrOpenStream(5) Expect(session.streamsMap.openStreams).To(HaveLen(2)) @@ -301,7 +302,7 @@ var _ = Describe("Session", func() { }) Context("handling RST_STREAM frames", func() { - It("closes the receiving streams for writing and reading", func() { + It("closes the streams for writing", func() { s, err := session.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) err = session.handleRstStreamFrame(&frames.RstStreamFrame{ @@ -312,9 +313,25 @@ var _ = Describe("Session", func() { n, err := s.Write([]byte{0}) Expect(n).To(BeZero()) Expect(err).To(MatchError("RST_STREAM received with code 42")) - n, err = s.Read([]byte{0}) - Expect(n).To(BeZero()) - Expect(err).To(MatchError("RST_STREAM received with code 42")) + }) + + It("doesn't close the stream for reading", func() { + s, err := session.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + }) + err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + StreamID: 5, + ErrorCode: 42, + ByteOffset: 6, + }) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 3) + n, err := s.Read(b) + Expect(n).To(Equal(3)) + Expect(err).ToNot(HaveOccurred()) }) It("queues a RST_STERAM frame with the correct offset", func() { diff --git a/stream.go b/stream.go index 7b22cb73..f86c7465 100644 --- a/stream.go +++ b/stream.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "sync" - "sync/atomic" "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" @@ -16,6 +15,8 @@ import ( // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { + mutex sync.Mutex + streamID protocol.StreamID onData func() @@ -23,14 +24,13 @@ type stream struct { writeOffset protocol.ByteCount readOffset protocol.ByteCount - // Once set, err must not be changed! - err error - mutex sync.Mutex + // Once set, the errors must not be changed! + readErr error + writeErr error - // eof is set if we are finished reading - eof int32 // really a bool - // closed is set when we are finished writing - closed int32 // really a bool + cancelled utils.AtomicBool + finishedReading utils.AtomicBool + finishedWriting utils.AtomicBool frameQueue *streamFrameSorter newFrameOrErrCond sync.Cond @@ -59,7 +59,10 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo // Read implements io.Reader. It is not thread safe! func (s *stream) Read(p []byte) (int, error) { - if atomic.LoadInt32(&s.eof) != 0 { + if s.cancelled.Get() { + return 0, s.readErr + } + if s.finishedReading.Get() { return 0, io.EOF } @@ -70,14 +73,14 @@ func (s *stream) Read(p []byte) (int, error) { if frame == nil && bytesRead > 0 { s.mutex.Unlock() - return bytesRead, s.err + return bytesRead, s.readErr } var err error for { // Stop waiting on errors - if s.err != nil { - err = s.err + if s.readErr != nil { + err = s.readErr break } if frame != nil { @@ -90,8 +93,10 @@ func (s *stream) Read(p []byte) (int, error) { s.mutex.Unlock() // Here, either frame != nil xor err != nil + // fmt.Printf("err: %#v, frame: %#v\n", err, frame) + if frame == nil { - atomic.StoreInt32(&s.eof, 1) + s.finishedReading.Set(true) // We have an err and no data, return the error return bytesRead, err } @@ -119,7 +124,7 @@ func (s *stream) Read(p []byte) (int, error) { s.frameQueue.Pop() s.mutex.Unlock() if fin { - atomic.StoreInt32(&s.eof, 1) + s.finishedReading.Set(true) return bytesRead, io.EOF } } @@ -132,8 +137,8 @@ func (s *stream) Write(p []byte) (int, error) { s.mutex.Lock() defer s.mutex.Unlock() - if s.err != nil { - return 0, s.err + if s.writeErr != nil { + return 0, s.writeErr } if len(p) == 0 { @@ -145,12 +150,12 @@ func (s *stream) Write(p []byte) (int, error) { s.onData() - for s.dataForWriting != nil && s.err == nil { + for s.dataForWriting != nil && s.writeErr == nil { s.doneWritingOrErrCond.Wait() } - if s.err != nil { - return 0, s.err + if s.writeErr != nil { + return 0, s.writeErr } return len(p), nil @@ -159,7 +164,7 @@ func (s *stream) Write(p []byte) (int, error) { func (s *stream) lenOfDataForWriting() protocol.ByteCount { s.mutex.Lock() var l protocol.ByteCount - if s.err == nil { + if s.writeErr == nil { l = protocol.ByteCount(len(s.dataForWriting)) } s.mutex.Unlock() @@ -168,7 +173,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() - if s.err != nil { + if s.writeErr != nil { s.mutex.Unlock() return nil } @@ -192,14 +197,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { // Close implements io.Closer func (s *stream) Close() error { - atomic.StoreInt32(&s.closed, 1) + s.finishedWriting.Set(true) s.onData() return nil } func (s *stream) shouldSendFin() bool { s.mutex.Lock() - res := atomic.LoadInt32(&s.closed) != 0 && !s.finSent && s.err == nil && s.dataForWriting == nil + res := s.finishedWriting.Get() && !s.finSent && s.writeErr == nil && s.dataForWriting == nil s.mutex.Unlock() return res } @@ -233,32 +238,50 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) } -// RegisterError is called by session to indicate that an error occurred and the -// stream should be closed. -func (s *stream) RegisterError(err error) { - atomic.StoreInt32(&s.closed, 1) +// Cancel is called by session to indicate that an error occurred +// The stream should will be closed immediately +func (s *stream) Cancel(err error) { + s.finishedReading.Set(true) + s.finishedWriting.Set(true) + s.cancelled.Set(true) + s.mutex.Lock() - defer s.mutex.Unlock() - if s.err != nil { // s.err must not be changed! - return + // errors must not be changed! + if s.readErr == nil { + s.readErr = err + s.newFrameOrErrCond.Signal() } - s.err = err - s.doneWritingOrErrCond.Signal() - s.newFrameOrErrCond.Signal() + if s.writeErr == nil { + s.writeErr = err + s.doneWritingOrErrCond.Signal() + } + s.mutex.Unlock() } -func (s *stream) finishedReading() bool { - return atomic.LoadInt32(&s.eof) != 0 +// resets the stream remotely +func (s *stream) RegisterRemoteError(err error) { + s.finishedWriting.Set(true) + s.mutex.Lock() + // errors must not be changed! + if s.writeErr == nil { + s.writeErr = err + s.doneWritingOrErrCond.Signal() + } + s.mutex.Unlock() } -func (s *stream) finishedWriting() bool { +func (s *stream) finishedRead() bool { + return s.finishedReading.Get() +} + +func (s *stream) finishedWriteAndSentFin() bool { s.mutex.Lock() defer s.mutex.Unlock() - return s.err != nil || (atomic.LoadInt32(&s.closed) != 0 && s.finSent) + return s.writeErr != nil || (s.finishedWriting.Get() && s.finSent) } func (s *stream) finished() bool { - return s.finishedReading() && s.finishedWriting() + return s.finishedRead() && s.finishedWriteAndSentFin() } func (s *stream) StreamID() protocol.StreamID { diff --git a/stream_framer_test.go b/stream_framer_test.go index 5c239619..a8802acf 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -239,7 +239,7 @@ var _ = Describe("Stream Framer", func() { Context("sending FINs", func() { It("sends FINs when streams are closed", func() { stream1.writeOffset = 42 - stream1.closed = 1 + stream1.finishedWriting.Set(true) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) Expect(fs[0].StreamID).To(Equal(stream1.streamID)) @@ -250,7 +250,7 @@ var _ = Describe("Stream Framer", func() { It("sends FINs when flow-control blocked", func() { stream1.writeOffset = 42 - stream1.closed = 1 + stream1.finishedWriting.Set(true) fcm.sendWindowSizes[stream1.StreamID()] = 42 fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) @@ -262,7 +262,7 @@ var _ = Describe("Stream Framer", func() { It("bundles FINs with data", func() { stream1.dataForWriting = []byte("foobar") - stream1.closed = 1 + stream1.finishedWriting.Set(true) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(1)) Expect(fs[0].StreamID).To(Equal(stream1.streamID)) @@ -401,7 +401,7 @@ var _ = Describe("Stream Framer", func() { frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) Expect(frames[0].FinBit).To(BeFalse()) - stream1.closed = 1 + stream1.finishedWriting.Set(true) frames = framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) Expect(frames[0].FinBit).To(BeTrue()) diff --git a/stream_test.go b/stream_test.go index 2b59dc32..7835280b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -288,6 +288,149 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) Expect(onDataCalled).To(BeTrue()) }) + + Context("closing", func() { + Context("with FIN bit", func() { + It("returns EOFs", func() { + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + FinBit: true, + } + str.AddStreamFrame(&frame) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("handles out-of-order frames", func() { + frame1 := frames.StreamFrame{ + Offset: 2, + Data: []byte{0xBE, 0xEF}, + FinBit: true, + } + frame2 := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + } + err := str.AddStreamFrame(&frame1) + Expect(err).ToNot(HaveOccurred()) + err = str.AddStreamFrame(&frame2) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) + n, err = str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + + It("returns EOFs with partial read", func() { + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD}, + FinBit: true, + } + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).To(MatchError(io.EOF)) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD})) + }) + + It("handles immediate FINs", func() { + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte{}, + FinBit: true, + } + err := str.AddStreamFrame(&frame) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) + + Context("when CloseRemote is called", func() { + It("closes", func() { + str.CloseRemote(0) + b := make([]byte, 8) + n, err := str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) + }) + + Context("cancelling the stream", func() { + testErr := errors.New("test error") + + It("immediately returns all reads", func() { + var readReturned bool + var n int + var err error + b := make([]byte, 4) + go func() { + n, err = str.Read(b) + readReturned = true + }() + Consistently(func() bool { return readReturned }).Should(BeFalse()) + str.Cancel(testErr) + Eventually(func() bool { return readReturned }).Should(BeTrue()) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + + It("errors for all following reads", func() { + str.Cancel(testErr) + b := make([]byte, 1) + n, err := str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("resetting", func() { + testErr := errors.New("received RST_STREAM") + + It("continues reading after receiving a remote error", func() { + frame := frames.StreamFrame{ + Offset: 0, + Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + } + str.AddStreamFrame(&frame) + str.RegisterRemoteError(testErr) + b := make([]byte, 4) + _, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + }) + + It("stops writing after receiving a remote error", func() { + var writeReturned bool + var n int + var err error + + go func() { + n, err = str.Write([]byte("foobar")) + writeReturned = true + }() + str.RegisterRemoteError(testErr) + Eventually(func() bool { return writeReturned }).Should(BeTrue()) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) }) Context("writing", func() { @@ -336,15 +479,6 @@ var _ = Describe("Stream", func() { Expect(str.lenOfDataForWriting()).To(Equal(protocol.ByteCount(0))) }) - It("returns remote errors", func(done Done) { - testErr := errors.New("test") - str.RegisterError(testErr) - n, err := str.Write([]byte("foo")) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - close(done) - }) - It("getDataForWriting returns nil if no data is available", func() { Expect(str.getDataForWriting(1000)).To(BeNil()) }) @@ -372,39 +506,63 @@ var _ = Describe("Stream", func() { Expect(n).To(BeZero()) Expect(err).ToNot(HaveOccurred()) }) - }) - Context("closing", func() { - It("sets closed when calling Close", func() { - str.Close() - Expect(str.closed).ToNot(BeZero()) + Context("closing", func() { + It("sets finishedWriting when calling Close", func() { + str.Close() + Expect(str.finishedWriting.Get()).To(BeTrue()) + }) + + It("allows FIN", func() { + str.Close() + Expect(str.shouldSendFin()).To(BeTrue()) + }) + + It("does not allow FIN when there's still data", func() { + str.dataForWriting = []byte("foobar") + str.Close() + Expect(str.shouldSendFin()).To(BeFalse()) + }) + + It("does not allow FIN when the stream is not closed", func() { + Expect(str.shouldSendFin()).To(BeFalse()) + }) + + It("does not allow FIN after an error", func() { + str.Cancel(errors.New("test")) + Expect(str.shouldSendFin()).To(BeFalse()) + }) + + It("does not allow FIN twice", func() { + str.Close() + Expect(str.shouldSendFin()).To(BeTrue()) + str.sentFin() + Expect(str.shouldSendFin()).To(BeFalse()) + }) }) - It("allows FIN", func() { - str.Close() - Expect(str.shouldSendFin()).To(BeTrue()) - }) + Context("cancelling", func() { + testErr := errors.New("test") - It("does not allow FIN when there's still data", func() { - str.dataForWriting = []byte("foobar") - str.Close() - Expect(str.shouldSendFin()).To(BeFalse()) - }) + It("returns errors when the stream is cancelled", func() { + str.Cancel(testErr) + n, err := str.Write([]byte("foo")) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(testErr)) + }) - It("does not allow FIN when the stream is not closed", func() { - Expect(str.shouldSendFin()).To(BeFalse()) - }) - - It("does not allow FIN after an error", func() { - str.RegisterError(errors.New("test")) - Expect(str.shouldSendFin()).To(BeFalse()) - }) - - It("does not allow FIN twice", func() { - str.Close() - Expect(str.shouldSendFin()).To(BeTrue()) - str.sentFin() - Expect(str.shouldSendFin()).To(BeFalse()) + It("doesn't get data for writing if an error occurred", func() { + go func() { + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(testErr)) + }() + Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) + Expect(str.lenOfDataForWriting()).ToNot(BeZero()) + str.Cancel(testErr) + data := str.getDataForWriting(6) + Expect(data).To(BeNil()) + Expect(str.lenOfDataForWriting()).To(BeZero()) + }) }) }) @@ -436,141 +594,4 @@ var _ = Describe("Stream", func() { }) }) - Context("closing", func() { - Context("with fin bit", func() { - It("returns EOFs", func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - FinBit: true, - } - str.AddStreamFrame(&frame) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("handles out-of-order frames", func() { - frame1 := frames.StreamFrame{ - Offset: 2, - Data: []byte{0xBE, 0xEF}, - FinBit: true, - } - frame2 := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - err := str.AddStreamFrame(&frame1) - Expect(err).ToNot(HaveOccurred()) - err = str.AddStreamFrame(&frame2) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns EOFs with partial read", func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - FinBit: true, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD})) - }) - - It("handles immediate FINs", func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{}, - FinBit: true, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - }) - - Context("with remote errors", func() { - testErr := errors.New("test error") - - It("returns EOF if data is read before", func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - FinBit: true, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - str.RegisterError(testErr) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).To(MatchError(io.EOF)) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - - It("returns errors", func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - } - err := str.AddStreamFrame(&frame) - Expect(err).ToNot(HaveOccurred()) - str.RegisterError(testErr) - b := make([]byte, 4) - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF})) - n, err = str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(testErr)) - }) - - It("doesn't get data for writing if an error occurred", func() { - go func() { - _, err := str.Write([]byte("foobar")) - Expect(err).To(MatchError(testErr)) - }() - Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) - Expect(str.lenOfDataForWriting()).ToNot(BeZero()) - str.RegisterError(testErr) - data := str.getDataForWriting(6) - Expect(data).To(BeNil()) - Expect(str.lenOfDataForWriting()).To(BeZero()) - }) - }) - - Context("when CloseRemote is called", func() { - It("closes", func() { - str.CloseRemote(0) - b := make([]byte, 8) - n, err := str.Read(b) - Expect(n).To(BeZero()) - Expect(err).To(MatchError(io.EOF)) - }) - }) - }) }) diff --git a/utils/atomic_bool.go b/utils/atomic_bool.go new file mode 100644 index 00000000..cf464250 --- /dev/null +++ b/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/utils/atomic_bool_test.go b/utils/atomic_bool_test.go new file mode 100644 index 00000000..83a200c2 --- /dev/null +++ b/utils/atomic_bool_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Atomic Bool", func() { + var a *AtomicBool + + BeforeEach(func() { + a = &AtomicBool{} + }) + + It("has the right default value", func() { + Expect(a.Get()).To(BeFalse()) + }) + + It("sets the value to true", func() { + a.Set(true) + Expect(a.Get()).To(BeTrue()) + }) + + It("sets the value to false", func() { + a.Set(true) + a.Set(false) + Expect(a.Get()).To(BeFalse()) + }) +})