From 8a3f807a1297cafa9d607738a41ebd00834daf92 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 25 Dec 2017 16:32:29 +0700 Subject: [PATCH] immediately delete a stream when it is completed By introducing a callback to the stream, which the stream calls as soon as it is completed, we can get rid of checking every single open stream if it is completed. --- mock_stream_internal_test.go | 12 -- mock_stream_sender_test.go | 10 ++ receive_stream.go | 10 +- receive_stream_test.go | 77 ++++-------- send_stream.go | 10 +- send_stream_test.go | 40 ++---- session.go | 10 +- session_test.go | 53 ++------ stream.go | 65 +++++++++- stream_test.go | 29 ++--- streams_map.go | 71 ++--------- streams_map_test.go | 235 ++++++++--------------------------- 12 files changed, 199 insertions(+), 423 deletions(-) diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 0722acd0d..6cbc8a971 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -169,18 +169,6 @@ func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) } -// finished mocks base method -func (m *MockStreamI) finished() bool { - ret := m.ctrl.Call(m, "finished") - ret0, _ := ret[0].(bool) - return ret0 -} - -// finished indicates an expected call of finished -func (mr *MockStreamIMockRecorder) finished() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "finished", reflect.TypeOf((*MockStreamI)(nil).finished)) -} - // getWindowUpdate mocks base method func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { ret := m.ctrl.Call(m, "getWindowUpdate") diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index 8549e343b..da3ad8d06 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -55,6 +55,16 @@ func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0) } +// onStreamCompleted mocks base method +func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { + m.ctrl.Call(m, "onStreamCompleted", arg0) +} + +// onStreamCompleted indicates an expected call of onStreamCompleted +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) +} + // queueControlFrame mocks base method func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { m.ctrl.Call(m, "queueControlFrame", arg0) diff --git a/receive_stream.go b/receive_stream.go index eb34f24c3..aae314b1c 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -148,6 +148,7 @@ func (s *receiveStream) Read(p []byte) (int, error) { s.frameQueue.Pop() s.finRead = frame.FinBit if frame.FinBit { + s.sender.onStreamCompleted(s.streamID) return bytesRead, io.EOF } } @@ -219,6 +220,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), } s.signalRead() + s.sender.onStreamCompleted(s.streamID) return nil } @@ -259,14 +261,6 @@ func (s *receiveStream) closeForShutdown(err error) { s.signalRead() } -func (s *receiveStream) finished() bool { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.closedForShutdown || // if the stream was abruptly closed for shutting down - s.finRead || s.resetRemotely -} - func (s *receiveStream) getWindowUpdate() protocol.ByteCount { return s.flowController.GetWindowUpdate() } diff --git a/receive_stream_test.go b/receive_stream_test.go index 98e1868a4..a6ac9f851 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -33,7 +33,6 @@ var _ = Describe("Receive Stream", func() { timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = gbytes.TimeoutReader(str, timeout) - strWithTimeout = str }) It("gets stream id", func() { @@ -320,12 +319,12 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) mockFC.EXPECT().HasWindowUpdate() - frame := wire.StreamFrame{ + str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, Data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, FinBit: true, - } - str.handleStreamFrame(&frame) + }) + mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).To(MatchError(io.EOF)) @@ -354,6 +353,7 @@ var _ = Describe("Receive Stream", func() { Expect(err).ToNot(HaveOccurred()) err = str.handleStreamFrame(&frame2) Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).To(MatchError(io.EOF)) @@ -368,31 +368,30 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) mockFC.EXPECT().HasWindowUpdate() - frame := wire.StreamFrame{ + err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, - Data: []byte{0xDE, 0xAD}, + Data: []byte{0xde, 0xad}, FinBit: true, - } - err := str.handleStreamFrame(&frame) + }) Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(err).To(MatchError(io.EOF)) Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD})) + Expect(b[:n]).To(Equal([]byte{0xde, 0xad})) }) It("handles immediate FINs", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) mockFC.EXPECT().HasWindowUpdate() - frame := wire.StreamFrame{ + err := str.handleStreamFrame(&wire.StreamFrame{ Offset: 0, - Data: []byte{}, FinBit: true, - } - err := str.handleStreamFrame(&frame) + }) Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 4) n, err := strWithTimeout.Read(b) Expect(n).To(BeZero()) @@ -405,6 +404,7 @@ var _ = Describe("Receive Stream", func() { mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) mockFC.EXPECT().HasWindowUpdate() str.CloseRemote(0) + mockSender.EXPECT().onStreamCompleted(streamID) b := make([]byte, 8) n, err := strWithTimeout.Read(b) Expect(n).To(BeZero()) @@ -486,6 +486,7 @@ var _ = Describe("Receive Stream", func() { FinBit: true, }) Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onStreamCompleted(streamID) _, err = strWithTimeout.Read(make([]byte, 100)) Expect(err).To(MatchError(io.EOF)) err = str.CancelRead(1234) @@ -526,11 +527,13 @@ var _ = Describe("Receive Stream", func() { close(done) }() Consistently(done).ShouldNot(BeClosed()) + mockSender.EXPECT().onStreamCompleted(streamID) str.handleRstStreamFrame(rst) Eventually(done).Should(BeClosed()) }) It("doesn't allow further calls to Read", func() { + mockSender.EXPECT().onStreamCompleted(streamID) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true) err := str.handleRstStreamFrame(rst) Expect(err).ToNot(HaveOccurred()) @@ -549,6 +552,7 @@ var _ = Describe("Receive Stream", func() { }) It("ignores duplicate RST_STREAM frames", func() { + mockSender.EXPECT().onStreamCompleted(streamID) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2) err := str.handleRstStreamFrame(rst) Expect(err).ToNot(HaveOccurred()) @@ -580,6 +584,7 @@ var _ = Describe("Receive Stream", func() { close(readReturned) }() Consistently(readReturned).ShouldNot(BeClosed()) + mockSender.EXPECT().onStreamCompleted(streamID) err := str.handleRstStreamFrame(rst) Expect(err).ToNot(HaveOccurred()) Eventually(readReturned).Should(BeClosed()) @@ -587,8 +592,11 @@ var _ = Describe("Receive Stream", func() { It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + gomock.InOrder( + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)), + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), + mockSender.EXPECT().onStreamCompleted(streamID), + ) mockFC.EXPECT().HasWindowUpdate().Times(2) readReturned := make(chan struct{}) go func() { @@ -637,43 +645,4 @@ var _ = Describe("Receive Stream", func() { Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100))) }) }) - - Context("saying if it is finished", func() { - finishReading := func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) - err := str.handleStreamFrame(&wire.StreamFrame{FinBit: true}) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 100) - _, err = strWithTimeout.Read(b) - ExpectWithOffset(0, err).To(MatchError(io.EOF)) - } - - It("is finished after it is closed for shutdown", func() { - str.closeForShutdown(errors.New("testErr")) - Expect(str.finished()).To(BeTrue()) - }) - - It("is finished if it is only closed for reading", func() { - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0)) - mockFC.EXPECT().HasWindowUpdate() - finishReading() - Expect(str.finished()).To(BeTrue()) - }) - - // the stream still needs to stay alive until we receive the final offset - // (either by receiving a STREAM frame with FIN, or a RST_STREAM) - It("is not finished after CancelRead", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - err := str.CancelRead(123) - Expect(err).ToNot(HaveOccurred()) - Expect(str.finished()).To(BeFalse()) - }) - - It("is finished after receiving a RST_STREAM frame", func() { - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(3), true) - err := str.handleRstStreamFrame(&wire.RstStreamFrame{ByteOffset: 3}) - Expect(err).ToNot(HaveOccurred()) - Expect(str.finished()).To(BeTrue()) - }) - }) }) diff --git a/send_stream.go b/send_stream.go index 01669e143..d337f756c 100644 --- a/send_stream.go +++ b/send_stream.go @@ -154,6 +154,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr } if frame.FinBit { s.finSent = true + s.sender.onStreamCompleted(s.streamID) } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream if isBlocked, offset := s.flowController.IsBlocked(); isBlocked { s.sender.queueControlFrame(&wire.StreamBlockedFrame{ @@ -231,6 +232,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr }) // TODO(#991): cancel retransmissions for this stream s.ctxCancel() + s.sender.onStreamCompleted(s.streamID) return nil } @@ -289,14 +291,6 @@ func (s *sendStream) closeForShutdown(err error) { s.ctxCancel() } -func (s *sendStream) finished() bool { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.closedForShutdown || // if the stream was abruptly closed for shutting down - s.finSent || s.canceledWrite -} - func (s *sendStream) getWriteOffset() protocol.ByteCount { return s.writeOffset } diff --git a/send_stream_test.go b/send_stream_test.go index 8e4df2250..d2718e188 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -247,6 +247,7 @@ var _ = Describe("Send Stream", func() { It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() { mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close + mockSender.EXPECT().onStreamCompleted(streamID) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) // don't EXPECT a call to mockFC.IsBlocked @@ -389,6 +390,7 @@ var _ = Describe("Send Stream", func() { It("allows FIN", func() { mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) str.Close() f, hasMoreData := str.popStreamFrame(1000) Expect(f).ToNot(BeNil()) @@ -409,6 +411,7 @@ var _ = Describe("Send Stream", func() { Expect(f).ToNot(BeNil()) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.FinBit).To(BeFalse()) + mockSender.EXPECT().onStreamCompleted(streamID) f, _ = str.popStreamFrame(100) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.FinBit).To(BeTrue()) @@ -423,6 +426,7 @@ var _ = Describe("Send Stream", func() { It("doesn't allow FIN twice", func() { mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) str.Close() f, _ := str.popStreamFrame(1000) Expect(f).ToNot(BeNil()) @@ -513,6 +517,7 @@ var _ = Describe("Send Stream", func() { ByteOffset: 1234, ErrorCode: 9876, }) + mockSender.EXPECT().onStreamCompleted(streamID) str.writeOffset = 1234 err := str.CancelWrite(9876) Expect(err).ToNot(HaveOccurred()) @@ -520,6 +525,7 @@ var _ = Describe("Send Stream", func() { It("unblocks Write", func() { mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().onStreamCompleted(streamID) mockSender.EXPECT().queueControlFrame(gomock.Any()) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) @@ -544,6 +550,7 @@ var _ = Describe("Send Stream", func() { It("cancels the context", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) Expect(str.Context().Done()).ToNot(BeClosed()) str.CancelWrite(1234) Expect(str.Context().Done()).To(BeClosed()) @@ -551,6 +558,7 @@ var _ = Describe("Send Stream", func() { It("doesn't allow further calls to Write", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) err := str.CancelWrite(1234) Expect(err).ToNot(HaveOccurred()) _, err = strWithTimeout.Write([]byte("foobar")) @@ -559,6 +567,7 @@ var _ = Describe("Send Stream", func() { It("only cancels once", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) err := str.CancelWrite(1234) Expect(err).ToNot(HaveOccurred()) err = str.CancelWrite(4321) @@ -580,6 +589,7 @@ var _ = Describe("Send Stream", func() { StreamID: streamID, ErrorCode: errorCodeStopping, }) + mockSender.EXPECT().onStreamCompleted(streamID) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, @@ -600,6 +610,7 @@ var _ = Describe("Send Stream", func() { close(done) }() waitForWrite() + mockSender.EXPECT().onStreamCompleted(streamID) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, @@ -609,6 +620,7 @@ var _ = Describe("Send Stream", func() { It("doesn't allow further calls to Write", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(streamID) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123, @@ -621,32 +633,4 @@ var _ = Describe("Send Stream", func() { }) }) }) - - Context("saying if it is finished", func() { - It("is finished after it is closed for shutdown", func() { - str.closeForShutdown(errors.New("testErr")) - Expect(str.finished()).To(BeTrue()) - }) - - It("is finished after Close()", func() { - mockSender.EXPECT().onHasStreamData(streamID) - str.Close() - f, _ := str.popStreamFrame(1000) - Expect(f.FinBit).To(BeTrue()) - Expect(str.finished()).To(BeTrue()) - }) - - It("is finished after CancelWrite", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - err := str.CancelWrite(123) - Expect(err).ToNot(HaveOccurred()) - Expect(str.finished()).To(BeTrue()) - }) - - It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID}) - Expect(str.finished()).To(BeTrue()) - }) - }) }) diff --git a/session.go b/session.go index 0e7a2d1fa..6357c5d08 100644 --- a/session.go +++ b/session.go @@ -408,10 +408,6 @@ runLoop: if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } - - if err := s.streamsMap.DeleteClosedStreams(); err != nil { - s.closeLocal(err) - } } // only send the error the handshakeChan when the handshake is not completed yet @@ -950,6 +946,12 @@ func (s *session) onHasStreamData(id protocol.StreamID) { s.scheduleSending() } +func (s *session) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.Close(err) + } +} + func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } diff --git a/session_test.go b/session_test.go index 85b63236b..82b2d88ad 100644 --- a/session_test.go +++ b/session_test.go @@ -122,7 +122,7 @@ var _ = Describe("Session", func() { ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) - Expect(sess.streamsMap.openStreams).To(BeEmpty()) + Expect(sess.streamsMap.streams).To(BeEmpty()) }) AfterEach(func() { @@ -195,9 +195,6 @@ var _ = Describe("Session", func() { sess.streamsMap.newStream = func(id protocol.StreamID) streamI { str := NewMockStreamI(mockCtrl) str.EXPECT().StreamID().Return(id).AnyTimes() - if id == 1 { - str.EXPECT().finished().AnyTimes() - } return str } }) @@ -247,9 +244,9 @@ var _ = Describe("Session", func() { return str } sess.handleStreamFrame(f1) - numOpenStreams := len(sess.streamsMap.openStreams) + numOpenStreams := len(sess.streamsMap.streams) sess.handleStreamFrame(f2) - Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + Expect(sess.streamsMap.streams).To(HaveLen(numOpenStreams)) }) It("ignores STREAM frames for closed streams", func() { @@ -329,8 +326,7 @@ var _ = Describe("Session", func() { It("ignores the error when the stream is not known", func() { str, err := sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - str.(*MockStreamI).EXPECT().finished().Return(true) - sess.streamsMap.DeleteClosedStreams() + sess.onStreamCompleted(3) str, err = sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) @@ -411,9 +407,7 @@ var _ = Describe("Session", func() { It("ignores MAX_STREAM_DATA frames for a closed stream", func() { str, err := sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - str.(*MockStreamI).EXPECT().finished().Return(true) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) + sess.onStreamCompleted(3) str, err = sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) @@ -457,8 +451,7 @@ var _ = Describe("Session", func() { It("ignores STOP_SENDING frames for a closed stream", func() { str, err := sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - str.(*MockStreamI).EXPECT().finished().Return(true) - err = sess.streamsMap.DeleteClosedStreams() + sess.onStreamCompleted(3) Expect(err).ToNot(HaveOccurred()) str, err = sess.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) @@ -1403,12 +1396,7 @@ var _ = Describe("Session", func() { It("returns a nil-value (not an interface with value nil) for closed streams", func() { str, err := sess.GetOrOpenStream(9) Expect(err).ToNot(HaveOccurred()) - str.Close() - str.(*stream).closeForShutdown(nil) - Expect(str.(*stream).finished()).To(BeTrue()) - err = sess.streamsMap.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil()) + sess.onStreamCompleted(9) str, err = sess.GetOrOpenStream(9) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) @@ -1425,31 +1413,6 @@ var _ = Describe("Session", func() { }) }) - Context("counting streams", func() { - It("errors when too many streams are opened", func() { - for i := 0; i < protocol.MaxIncomingStreams; i++ { - _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - } - _, err := sess.GetOrOpenStream(protocol.StreamID(301)) - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - }) - - It("does not error when many streams are opened and closed", func() { - for i := 2; i <= 1000; i++ { - s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - Expect(s.Close()).To(Succeed()) - f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit - Expect(f.FinBit).To(BeTrue()) - s.(*stream).CloseRemote(0) - _, err = s.Read([]byte("a")) - Expect(err).To(MatchError(io.EOF)) - sess.streamsMap.DeleteClosedStreams() - } - }) - }) - Context("ignoring errors", func() { It("ignores duplicate acks", func() { sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ @@ -1522,7 +1485,7 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.openStreams).To(BeEmpty()) + Expect(sess.streamsMap.streams).To(BeEmpty()) }) AfterEach(func() { diff --git a/stream.go b/stream.go index 80cbb7f10..d13584972 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,7 @@ package quic import ( "net" + "sync" "time" "github.com/lucas-clemente/quic-go/internal/flowcontrol" @@ -19,8 +20,34 @@ type streamSender interface { queueControlFrame(wire.Frame) onHasWindowUpdate(protocol.StreamID) onHasStreamData(protocol.StreamID) + onStreamCompleted(protocol.StreamID) } +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) { + s.streamSender.onHasWindowUpdate(id) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + type streamI interface { Stream @@ -28,7 +55,6 @@ type streamI interface { handleRstStreamFrame(*wire.RstStreamFrame) error handleStopSendingFrame(*wire.StopSendingFrame) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) - finished() bool closeForShutdown(error) // methods needed for flow control getWindowUpdate() protocol.ByteCount @@ -42,6 +68,11 @@ type stream struct { receiveStream sendStream + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool + version protocol.VersionNumber } @@ -72,10 +103,28 @@ func newStream(streamID protocol.StreamID, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { - return &stream{ - sendStream: *newSendStream(streamID, sender, flowController, version), - receiveStream: *newReceiveStream(streamID, sender, flowController), + s := &stream{sender: sender} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, } + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController) + return s } // need to define StreamID() here, since both receiveStream and readStream have a StreamID() @@ -120,6 +169,10 @@ func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { return nil } -func (s *stream) finished() bool { - return s.sendStream.finished() && s.receiveStream.finished() +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) + } } diff --git a/stream_test.go b/stream_test.go index d33d66ec2..e35af10a8 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,7 +1,6 @@ package quic import ( - "errors" "io" "os" "strconv" @@ -72,6 +71,7 @@ var _ = Describe("Stream", func() { ByteOffset: 1000, ErrorCode: errorCodeStoppingGQUIC, }) + mockSender.EXPECT().onStreamCompleted(streamID) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true) str.writeOffset = 1000 f := &wire.RstStreamFrame{ @@ -189,26 +189,21 @@ var _ = Describe("Stream", func() { }) }) - Context("saying if it is finished", func() { - It("is finished when both the send and the receive side are finished", func() { - str.receiveStream.closeForShutdown(errors.New("shutdown")) - Expect(str.receiveStream.finished()).To(BeTrue()) - Expect(str.sendStream.finished()).To(BeFalse()) - Expect(str.finished()).To(BeFalse()) + Context("completing", func() { + It("is not completed when only the receive side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.receiveStream.sender.onStreamCompleted(streamID) }) - It("is not finished when the receive side is finished", func() { - str.sendStream.closeForShutdown(errors.New("shutdown")) - Expect(str.receiveStream.finished()).To(BeFalse()) - Expect(str.sendStream.finished()).To(BeTrue()) - Expect(str.finished()).To(BeFalse()) + It("is not completed when only the send side is completed", func() { + // don't EXPECT a call to mockSender.onStreamCompleted() + str.sendStream.sender.onStreamCompleted(streamID) }) - It("is not finished when the send side is finished", func() { - str.closeForShutdown(errors.New("shutdown")) - Expect(str.receiveStream.finished()).To(BeTrue()) - Expect(str.sendStream.finished()).To(BeTrue()) - Expect(str.finished()).To(BeTrue()) + It("is completed when both sides are completed", func() { + mockSender.EXPECT().onStreamCompleted(streamID) + str.sendStream.sender.onStreamCompleted(streamID) + str.receiveStream.sender.onStreamCompleted(streamID) }) }) }) diff --git a/streams_map.go b/streams_map.go index 907258cfd..956b725d0 100644 --- a/streams_map.go +++ b/streams_map.go @@ -16,9 +16,6 @@ type streamsMap struct { perspective protocol.Perspective streams map[protocol.StreamID]streamI - // needed for round-robin scheduling - openStreams []protocol.StreamID - roundRobinIndex int nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID @@ -51,7 +48,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro sm := streamsMap{ perspective: pers, streams: make(map[protocol.StreamID]streamI), - openStreams: make([]protocol.StreamID, 0), newStream: newStream, maxIncomingStreams: maxIncomingStreams, } @@ -99,7 +95,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { s, ok := m.streams[id] m.mutex.RUnlock() if ok { - return s, nil // s may be nil + return s, nil } // ... we don't have an existing stream @@ -212,48 +208,19 @@ func (m *streamsMap) AcceptStream() (streamI, error) { return str, nil } -func (m *streamsMap) DeleteClosedStreams() error { +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { m.mutex.Lock() defer m.mutex.Unlock() - - var numDeletedStreams int - // for every closed stream, the streamID is replaced by 0 in the openStreams slice - for i, streamID := range m.openStreams { - str, ok := m.streams[streamID] - if !ok { - return errMapAccess - } - if !str.finished() { - continue - } - numDeletedStreams++ - m.openStreams[i] = 0 - if m.streamInitiatedBy(streamID) == m.perspective { - m.numOutgoingStreams-- - } else { - m.numIncomingStreams-- - } - delete(m.streams, streamID) + _, ok := m.streams[id] + if !ok { + return errMapAccess } - - if numDeletedStreams == 0 { - return nil + delete(m.streams, id) + if m.streamInitiatedBy(id) == m.perspective { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- } - - // remove all 0s (representing closed streams) from the openStreams slice - // and adjust the roundRobinIndex - var j int - for i, id := range m.openStreams { - if i != j { - m.openStreams[j] = m.openStreams[i] - } - if id != 0 { - j++ - } else if j < m.roundRobinIndex { - m.roundRobinIndex-- - } - } - m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams] m.openStreamOrErrCond.Signal() return nil } @@ -264,28 +231,16 @@ func (m *streamsMap) Range(cb func(s streamI)) { defer m.mutex.RUnlock() for _, s := range m.streams { - if s != nil { - cb(s) - } + cb(s) } } -func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { - str, ok := m.streams[streamID] - if !ok { - return true, errMapAccess - } - return fn(str) -} - func (m *streamsMap) putStream(s streamI) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) } - m.streams[id] = s - m.openStreams = append(m.openStreams, id) return nil } @@ -295,8 +250,8 @@ func (m *streamsMap) CloseWithError(err error) { m.closeErr = err m.nextStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast() - for _, s := range m.openStreams { - m.streams[s].closeForShutdown(err) + for _, s := range m.streams { + s.closeForShutdown(err) } } diff --git a/streams_map_test.go b/streams_map_test.go index 3d7f56376..6bb7e71c9 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -7,22 +7,16 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Streams Map", func() { - var ( - m *streamsMap - finishedStreams map[protocol.StreamID]*gomock.Call - ) + var m *streamsMap newStream := func(id protocol.StreamID) streamI { str := NewMockStreamI(mockCtrl) str.EXPECT().StreamID().Return(id).AnyTimes() - c := str.EXPECT().finished().Return(false).AnyTimes() - finishedStreams[id] = c return str } @@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() { m = newStreamsMap(newStream, p, v) } - BeforeEach(func() { - finishedStreams = make(map[protocol.StreamID]*gomock.Call) - }) - - AfterEach(func() { - Expect(m.openStreams).To(HaveLen(len(m.streams))) - }) - deleteStream := func(id protocol.StreamID) { - str := m.streams[id] - Expect(str).ToNot(BeNil()) - finishedStreams[id].Return(true) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed()) } Context("getting and creating streams", func() { @@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() { }) }) - Context("DoS mitigation, iterating and deleting", func() { + Context("Ranging", func() { + It("ranges over all open streams", func() { + setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever) + var callbackCalledForStream []protocol.StreamID + callback := func(str streamI) { + callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) + sort.Slice(callbackCalledForStream, func(i, j int) bool { + return callbackCalledForStream[i] < callbackCalledForStream[j] + }) + } + + Expect(m.streams).To(BeEmpty()) + // create 5 streams, ids 4 to 8 + callbackCalledForStream = callbackCalledForStream[:0] + for i := 4; i <= 8; i++ { + str := NewMockStreamI(mockCtrl) + str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes() + err := m.putStream(str) + Expect(err).NotTo(HaveOccurred()) + } + // execute the callback for all streams + m.Range(callback) + Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) + }) + }) + + Context("deleting streams", func() { BeforeEach(func() { setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) }) - closeStream := func(id protocol.StreamID) { - str := m.streams[id] - ExpectWithOffset(1, str).ToNot(BeNil()) - finishedStreams[id].Return(true) - } - - Context("deleting streams", func() { - Context("as a server", func() { - BeforeEach(func() { - m.UpdateMaxStreamLimit(100) - for i := 1; i <= 5; i++ { - if i%2 == 1 { - _, err := m.openRemoteStream(protocol.StreamID(i)) - Expect(err).ToNot(HaveOccurred()) - } else { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - } - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4 - Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5 - }) - - It("does not delete streams with Close()", func() { - str, err := m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - str.(*MockStreamI).EXPECT().Close() - str.Close() - err = m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) - - It("removes the first stream", func() { - closeStream(1) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a stream in the middle", func() { - closeStream(3) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a client-initiated stream", func() { - closeStream(2) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) - }) - - It("removes a stream at the end", func() { - closeStream(5) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes all streams", func() { - for i := 1; i <= 5; i++ { - closeStream(protocol.StreamID(i)) - } - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(BeEmpty()) - Expect(m.openStreams).To(BeEmpty()) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeZero()) - }) - }) - - Context("as a client", func() { - BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames) - m.UpdateMaxStreamLimit(100) - for i := 1; i <= 5; i++ { - if i%2 == 0 { - _, err := m.openRemoteStream(protocol.StreamID(i)) - Expect(err).ToNot(HaveOccurred()) - } else { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - } - } - Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7 - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4 - }) - - It("removes a stream that we initiated", func() { - closeStream(3) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - }) - - It("removes a stream that the server initiated", func() { - closeStream(2) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7})) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) - }) - - It("removes all streams", func() { - closeStream(3) - closeStream(2) - closeStream(5) - closeStream(4) - closeStream(7) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(BeEmpty()) - Expect(m.openStreams).To(BeEmpty()) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeZero()) - }) - }) + It("deletes an incoming stream", func() { + _, err := m.GetOrOpenStream(5) // open stream 3 and 5 + Expect(err).ToNot(HaveOccurred()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(1)) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) }) - Context("Ranging", func() { - // create 5 streams, ids 4 to 8 - var callbackCalledForStream []protocol.StreamID - callback := func(str streamI) { - callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) - sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] }) - } + It("deletes an outgoing stream", func() { + m.UpdateMaxStreamLimit(10000) + _, err := m.OpenStream() // open stream 2 + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + }) - BeforeEach(func() { - callbackCalledForStream = callbackCalledForStream[:0] - for i := 4; i <= 8; i++ { - str := NewMockStreamI(mockCtrl) - str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes() - err := m.putStream(str) - Expect(err).NotTo(HaveOccurred()) - } - }) - - It("ranges over all open streams", func() { - m.Range(callback) - Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) - }) + It("errors when the stream doesn't exist", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError(errMapAccess)) }) }) })