forked from quic-go/quic-go
Merge pull request #1081 from lucas-clemente/stream-completed-callback
immediately delete a stream when it is completed
This commit is contained in:
@@ -169,18 +169,6 @@ func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
|
||||
}
|
||||
|
||||
// finished mocks base method
|
||||
func (m *MockStreamI) finished() bool {
|
||||
ret := m.ctrl.Call(m, "finished")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// finished indicates an expected call of finished
|
||||
func (mr *MockStreamIMockRecorder) finished() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "finished", reflect.TypeOf((*MockStreamI)(nil).finished))
|
||||
}
|
||||
|
||||
// getWindowUpdate mocks base method
|
||||
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
|
||||
ret := m.ctrl.Call(m, "getWindowUpdate")
|
||||
|
||||
@@ -55,6 +55,16 @@ func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted mocks base method
|
||||
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
|
||||
m.ctrl.Call(m, "onStreamCompleted", arg0)
|
||||
}
|
||||
|
||||
// onStreamCompleted indicates an expected call of onStreamCompleted
|
||||
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
|
||||
}
|
||||
|
||||
// queueControlFrame mocks base method
|
||||
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
|
||||
m.ctrl.Call(m, "queueControlFrame", arg0)
|
||||
|
||||
@@ -148,6 +148,7 @@ func (s *receiveStream) Read(p []byte) (int, error) {
|
||||
s.frameQueue.Pop()
|
||||
s.finRead = frame.FinBit
|
||||
if frame.FinBit {
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return bytesRead, io.EOF
|
||||
}
|
||||
}
|
||||
@@ -219,6 +220,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
|
||||
}
|
||||
s.signalRead()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -259,14 +261,6 @@ func (s *receiveStream) closeForShutdown(err error) {
|
||||
s.signalRead()
|
||||
}
|
||||
|
||||
func (s *receiveStream) finished() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
|
||||
s.finRead || s.resetRemotely
|
||||
}
|
||||
|
||||
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
|
||||
return s.flowController.GetWindowUpdate()
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ var _ = Describe("Receive Stream", func() {
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = gbytes.TimeoutReader(str, timeout)
|
||||
strWithTimeout = str
|
||||
})
|
||||
|
||||
It("gets stream id", func() {
|
||||
@@ -320,12 +319,12 @@ var _ = Describe("Receive Stream", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
FinBit: true,
|
||||
}
|
||||
str.handleStreamFrame(&frame)
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
@@ -354,6 +353,7 @@ var _ = Describe("Receive Stream", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.handleStreamFrame(&frame2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
@@ -368,31 +368,30 @@ var _ = Describe("Receive Stream", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{0xDE, 0xAD},
|
||||
Data: []byte{0xde, 0xad},
|
||||
FinBit: true,
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(n).To(Equal(2))
|
||||
Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD}))
|
||||
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
|
||||
})
|
||||
|
||||
It("handles immediate FINs", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
frame := wire.StreamFrame{
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{
|
||||
Offset: 0,
|
||||
Data: []byte{},
|
||||
FinBit: true,
|
||||
}
|
||||
err := str.handleStreamFrame(&frame)
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 4)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
@@ -405,6 +404,7 @@ var _ = Describe("Receive Stream", func() {
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
str.CloseRemote(0)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
b := make([]byte, 8)
|
||||
n, err := strWithTimeout.Read(b)
|
||||
Expect(n).To(BeZero())
|
||||
@@ -486,6 +486,7 @@ var _ = Describe("Receive Stream", func() {
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
_, err = strWithTimeout.Read(make([]byte, 100))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
err = str.CancelRead(1234)
|
||||
@@ -526,11 +527,13 @@ var _ = Describe("Receive Stream", func() {
|
||||
close(done)
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleRstStreamFrame(rst)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't allow further calls to Read", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -549,6 +552,7 @@ var _ = Describe("Receive Stream", func() {
|
||||
})
|
||||
|
||||
It("ignores duplicate RST_STREAM frames", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -580,6 +584,7 @@ var _ = Describe("Receive Stream", func() {
|
||||
close(readReturned)
|
||||
}()
|
||||
Consistently(readReturned).ShouldNot(BeClosed())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.handleRstStreamFrame(rst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(readReturned).Should(BeClosed())
|
||||
@@ -587,8 +592,11 @@ var _ = Describe("Receive Stream", func() {
|
||||
|
||||
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
|
||||
gomock.InOrder(
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
|
||||
mockSender.EXPECT().onStreamCompleted(streamID),
|
||||
)
|
||||
mockFC.EXPECT().HasWindowUpdate().Times(2)
|
||||
readReturned := make(chan struct{})
|
||||
go func() {
|
||||
@@ -637,43 +645,4 @@ var _ = Describe("Receive Stream", func() {
|
||||
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
finishReading := func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
|
||||
err := str.handleStreamFrame(&wire.StreamFrame{FinBit: true})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 100)
|
||||
_, err = strWithTimeout.Read(b)
|
||||
ExpectWithOffset(0, err).To(MatchError(io.EOF))
|
||||
}
|
||||
|
||||
It("is finished after it is closed for shutdown", func() {
|
||||
str.closeForShutdown(errors.New("testErr"))
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished if it is only closed for reading", func() {
|
||||
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
|
||||
mockFC.EXPECT().HasWindowUpdate()
|
||||
finishReading()
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
// the stream still needs to stay alive until we receive the final offset
|
||||
// (either by receiving a STREAM frame with FIN, or a RST_STREAM)
|
||||
It("is not finished after CancelRead", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelRead(123)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is finished after receiving a RST_STREAM frame", func() {
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(3), true)
|
||||
err := str.handleRstStreamFrame(&wire.RstStreamFrame{ByteOffset: 3})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -154,6 +154,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr
|
||||
}
|
||||
if frame.FinBit {
|
||||
s.finSent = true
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
|
||||
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
|
||||
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
|
||||
@@ -231,6 +232,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr
|
||||
})
|
||||
// TODO(#991): cancel retransmissions for this stream
|
||||
s.ctxCancel()
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -289,14 +291,6 @@ func (s *sendStream) closeForShutdown(err error) {
|
||||
s.ctxCancel()
|
||||
}
|
||||
|
||||
func (s *sendStream) finished() bool {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
|
||||
s.finSent || s.canceledWrite
|
||||
}
|
||||
|
||||
func (s *sendStream) getWriteOffset() protocol.ByteCount {
|
||||
return s.writeOffset
|
||||
}
|
||||
|
||||
@@ -247,6 +247,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
|
||||
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
|
||||
// don't EXPECT a call to mockFC.IsBlocked
|
||||
@@ -389,6 +390,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("allows FIN", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, hasMoreData := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
@@ -409,6 +411,7 @@ var _ = Describe("Send Stream", func() {
|
||||
Expect(f).ToNot(BeNil())
|
||||
Expect(f.Data).To(Equal([]byte("foo")))
|
||||
Expect(f.FinBit).To(BeFalse())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
f, _ = str.popStreamFrame(100)
|
||||
Expect(f.Data).To(Equal([]byte("bar")))
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
@@ -423,6 +426,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't allow FIN twice", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.Close()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f).ToNot(BeNil())
|
||||
@@ -513,6 +517,7 @@ var _ = Describe("Send Stream", func() {
|
||||
ByteOffset: 1234,
|
||||
ErrorCode: 9876,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.writeOffset = 1234
|
||||
err := str.CancelWrite(9876)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -520,6 +525,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("unblocks Write", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
|
||||
mockFC.EXPECT().AddBytesSent(gomock.Any())
|
||||
@@ -544,6 +550,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("cancels the context", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
Expect(str.Context().Done()).ToNot(BeClosed())
|
||||
str.CancelWrite(1234)
|
||||
Expect(str.Context().Done()).To(BeClosed())
|
||||
@@ -551,6 +558,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
@@ -559,6 +567,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("only cancels once", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
err := str.CancelWrite(1234)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.CancelWrite(4321)
|
||||
@@ -580,6 +589,7 @@ var _ = Describe("Send Stream", func() {
|
||||
StreamID: streamID,
|
||||
ErrorCode: errorCodeStopping,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 101,
|
||||
@@ -600,6 +610,7 @@ var _ = Describe("Send Stream", func() {
|
||||
close(done)
|
||||
}()
|
||||
waitForWrite()
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
@@ -609,6 +620,7 @@ var _ = Describe("Send Stream", func() {
|
||||
|
||||
It("doesn't allow further calls to Write", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{
|
||||
StreamID: streamID,
|
||||
ErrorCode: 123,
|
||||
@@ -621,32 +633,4 @@ var _ = Describe("Send Stream", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
It("is finished after it is closed for shutdown", func() {
|
||||
str.closeForShutdown(errors.New("testErr"))
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after Close()", func() {
|
||||
mockSender.EXPECT().onHasStreamData(streamID)
|
||||
str.Close()
|
||||
f, _ := str.popStreamFrame(1000)
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after CancelWrite", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
err := str.CancelWrite(123)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() {
|
||||
mockSender.EXPECT().queueControlFrame(gomock.Any())
|
||||
str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID})
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
10
session.go
10
session.go
@@ -408,10 +408,6 @@ runLoop:
|
||||
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
|
||||
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
|
||||
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// only send the error the handshakeChan when the handshake is not completed yet
|
||||
@@ -950,6 +946,12 @@ func (s *session) onHasStreamData(id protocol.StreamID) {
|
||||
s.scheduleSending()
|
||||
}
|
||||
|
||||
func (s *session) onStreamCompleted(id protocol.StreamID) {
|
||||
if err := s.streamsMap.DeleteStream(id); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ var _ = Describe("Session", func() {
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
sess = pSess.(*session)
|
||||
Expect(sess.streamsMap.openStreams).To(BeEmpty())
|
||||
Expect(sess.streamsMap.streams).To(BeEmpty())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
@@ -195,9 +195,6 @@ var _ = Describe("Session", func() {
|
||||
sess.streamsMap.newStream = func(id protocol.StreamID) streamI {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(id).AnyTimes()
|
||||
if id == 1 {
|
||||
str.EXPECT().finished().AnyTimes()
|
||||
}
|
||||
return str
|
||||
}
|
||||
})
|
||||
@@ -247,9 +244,9 @@ var _ = Describe("Session", func() {
|
||||
return str
|
||||
}
|
||||
sess.handleStreamFrame(f1)
|
||||
numOpenStreams := len(sess.streamsMap.openStreams)
|
||||
numOpenStreams := len(sess.streamsMap.streams)
|
||||
sess.handleStreamFrame(f2)
|
||||
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
|
||||
Expect(sess.streamsMap.streams).To(HaveLen(numOpenStreams))
|
||||
})
|
||||
|
||||
It("ignores STREAM frames for closed streams", func() {
|
||||
@@ -329,8 +326,7 @@ var _ = Describe("Session", func() {
|
||||
It("ignores the error when the stream is not known", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
sess.streamsMap.DeleteClosedStreams()
|
||||
sess.onStreamCompleted(3)
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
@@ -411,9 +407,7 @@ var _ = Describe("Session", func() {
|
||||
It("ignores MAX_STREAM_DATA frames for a closed stream", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sess.onStreamCompleted(3)
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
@@ -457,8 +451,7 @@ var _ = Describe("Session", func() {
|
||||
It("ignores STOP_SENDING frames for a closed stream", func() {
|
||||
str, err := sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().finished().Return(true)
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
sess.onStreamCompleted(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err = sess.GetOrOpenStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -1403,12 +1396,7 @@ var _ = Describe("Session", func() {
|
||||
It("returns a nil-value (not an interface with value nil) for closed streams", func() {
|
||||
str, err := sess.GetOrOpenStream(9)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
str.(*stream).closeForShutdown(nil)
|
||||
Expect(str.(*stream).finished()).To(BeTrue())
|
||||
err = sess.streamsMap.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
|
||||
sess.onStreamCompleted(9)
|
||||
str, err = sess.GetOrOpenStream(9)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).To(BeNil())
|
||||
@@ -1425,31 +1413,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("counting streams", func() {
|
||||
It("errors when too many streams are opened", func() {
|
||||
for i := 0; i < protocol.MaxIncomingStreams; i++ {
|
||||
_, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
_, err := sess.GetOrOpenStream(protocol.StreamID(301))
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
})
|
||||
|
||||
It("does not error when many streams are opened and closed", func() {
|
||||
for i := 2; i <= 1000; i++ {
|
||||
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s.Close()).To(Succeed())
|
||||
f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit
|
||||
Expect(f.FinBit).To(BeTrue())
|
||||
s.(*stream).CloseRemote(0)
|
||||
_, err = s.Read([]byte("a"))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
sess.streamsMap.DeleteClosedStreams()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("ignoring errors", func() {
|
||||
It("ignores duplicate acks", func() {
|
||||
sess.sentPacketHandler.SentPacket(&ackhandler.Packet{
|
||||
@@ -1522,7 +1485,7 @@ var _ = Describe("Client Session", func() {
|
||||
)
|
||||
sess = sessP.(*session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.streamsMap.openStreams).To(BeEmpty())
|
||||
Expect(sess.streamsMap.streams).To(BeEmpty())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
||||
65
stream.go
65
stream.go
@@ -2,6 +2,7 @@ package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
@@ -19,8 +20,34 @@ type streamSender interface {
|
||||
queueControlFrame(wire.Frame)
|
||||
onHasWindowUpdate(protocol.StreamID)
|
||||
onHasStreamData(protocol.StreamID)
|
||||
onStreamCompleted(protocol.StreamID)
|
||||
}
|
||||
|
||||
// Each of the both stream halves gets its own uniStreamSender.
|
||||
// This is necessary in order to keep track when both halves have been completed.
|
||||
type uniStreamSender struct {
|
||||
streamSender
|
||||
onStreamCompletedImpl func()
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
|
||||
s.streamSender.queueControlFrame(f)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
|
||||
s.streamSender.onHasWindowUpdate(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
|
||||
s.streamSender.onHasStreamData(id)
|
||||
}
|
||||
|
||||
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
|
||||
s.onStreamCompletedImpl()
|
||||
}
|
||||
|
||||
var _ streamSender = &uniStreamSender{}
|
||||
|
||||
type streamI interface {
|
||||
Stream
|
||||
|
||||
@@ -28,7 +55,6 @@ type streamI interface {
|
||||
handleRstStreamFrame(*wire.RstStreamFrame) error
|
||||
handleStopSendingFrame(*wire.StopSendingFrame)
|
||||
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
finished() bool
|
||||
closeForShutdown(error)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
@@ -42,6 +68,11 @@ type stream struct {
|
||||
receiveStream
|
||||
sendStream
|
||||
|
||||
completedMutex sync.Mutex
|
||||
sender streamSender
|
||||
receiveStreamCompleted bool
|
||||
sendStreamCompleted bool
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
@@ -72,10 +103,28 @@ func newStream(streamID protocol.StreamID,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
version protocol.VersionNumber,
|
||||
) *stream {
|
||||
return &stream{
|
||||
sendStream: *newSendStream(streamID, sender, flowController, version),
|
||||
receiveStream: *newReceiveStream(streamID, sender, flowController),
|
||||
s := &stream{sender: sender}
|
||||
senderForSendStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.sendStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
|
||||
senderForReceiveStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
s.completedMutex.Lock()
|
||||
s.receiveStreamCompleted = true
|
||||
s.checkIfCompleted()
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
|
||||
return s
|
||||
}
|
||||
|
||||
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
|
||||
@@ -120,6 +169,10 @@ func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stream) finished() bool {
|
||||
return s.sendStream.finished() && s.receiveStream.finished()
|
||||
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
|
||||
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
|
||||
func (s *stream) checkIfCompleted() {
|
||||
if s.sendStreamCompleted && s.receiveStreamCompleted {
|
||||
s.sender.onStreamCompleted(s.StreamID())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -72,6 +71,7 @@ var _ = Describe("Stream", func() {
|
||||
ByteOffset: 1000,
|
||||
ErrorCode: errorCodeStoppingGQUIC,
|
||||
})
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
|
||||
str.writeOffset = 1000
|
||||
f := &wire.RstStreamFrame{
|
||||
@@ -189,26 +189,21 @@ var _ = Describe("Stream", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("saying if it is finished", func() {
|
||||
It("is finished when both the send and the receive side are finished", func() {
|
||||
str.receiveStream.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeTrue())
|
||||
Expect(str.sendStream.finished()).To(BeFalse())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
Context("completing", func() {
|
||||
It("is not completed when only the receive side is completed", func() {
|
||||
// don't EXPECT a call to mockSender.onStreamCompleted()
|
||||
str.receiveStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
|
||||
It("is not finished when the receive side is finished", func() {
|
||||
str.sendStream.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeFalse())
|
||||
Expect(str.sendStream.finished()).To(BeTrue())
|
||||
Expect(str.finished()).To(BeFalse())
|
||||
It("is not completed when only the send side is completed", func() {
|
||||
// don't EXPECT a call to mockSender.onStreamCompleted()
|
||||
str.sendStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
|
||||
It("is not finished when the send side is finished", func() {
|
||||
str.closeForShutdown(errors.New("shutdown"))
|
||||
Expect(str.receiveStream.finished()).To(BeTrue())
|
||||
Expect(str.sendStream.finished()).To(BeTrue())
|
||||
Expect(str.finished()).To(BeTrue())
|
||||
It("is completed when both sides are completed", func() {
|
||||
mockSender.EXPECT().onStreamCompleted(streamID)
|
||||
str.sendStream.sender.onStreamCompleted(streamID)
|
||||
str.receiveStream.sender.onStreamCompleted(streamID)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -16,9 +16,6 @@ type streamsMap struct {
|
||||
perspective protocol.Perspective
|
||||
|
||||
streams map[protocol.StreamID]streamI
|
||||
// needed for round-robin scheduling
|
||||
openStreams []protocol.StreamID
|
||||
roundRobinIndex int
|
||||
|
||||
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
@@ -51,7 +48,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
|
||||
sm := streamsMap{
|
||||
perspective: pers,
|
||||
streams: make(map[protocol.StreamID]streamI),
|
||||
openStreams: make([]protocol.StreamID, 0),
|
||||
newStream: newStream,
|
||||
maxIncomingStreams: maxIncomingStreams,
|
||||
}
|
||||
@@ -99,7 +95,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
|
||||
s, ok := m.streams[id]
|
||||
m.mutex.RUnlock()
|
||||
if ok {
|
||||
return s, nil // s may be nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ... we don't have an existing stream
|
||||
@@ -212,48 +208,19 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func (m *streamsMap) DeleteClosedStreams() error {
|
||||
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var numDeletedStreams int
|
||||
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
|
||||
for i, streamID := range m.openStreams {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
if !str.finished() {
|
||||
continue
|
||||
}
|
||||
numDeletedStreams++
|
||||
m.openStreams[i] = 0
|
||||
if m.streamInitiatedBy(streamID) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
delete(m.streams, streamID)
|
||||
_, ok := m.streams[id]
|
||||
if !ok {
|
||||
return errMapAccess
|
||||
}
|
||||
|
||||
if numDeletedStreams == 0 {
|
||||
return nil
|
||||
delete(m.streams, id)
|
||||
if m.streamInitiatedBy(id) == m.perspective {
|
||||
m.numOutgoingStreams--
|
||||
} else {
|
||||
m.numIncomingStreams--
|
||||
}
|
||||
|
||||
// remove all 0s (representing closed streams) from the openStreams slice
|
||||
// and adjust the roundRobinIndex
|
||||
var j int
|
||||
for i, id := range m.openStreams {
|
||||
if i != j {
|
||||
m.openStreams[j] = m.openStreams[i]
|
||||
}
|
||||
if id != 0 {
|
||||
j++
|
||||
} else if j < m.roundRobinIndex {
|
||||
m.roundRobinIndex--
|
||||
}
|
||||
}
|
||||
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
|
||||
m.openStreamOrErrCond.Signal()
|
||||
return nil
|
||||
}
|
||||
@@ -264,28 +231,16 @@ func (m *streamsMap) Range(cb func(s streamI)) {
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, s := range m.streams {
|
||||
if s != nil {
|
||||
cb(s)
|
||||
}
|
||||
cb(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
|
||||
str, ok := m.streams[streamID]
|
||||
if !ok {
|
||||
return true, errMapAccess
|
||||
}
|
||||
return fn(str)
|
||||
}
|
||||
|
||||
func (m *streamsMap) putStream(s streamI) error {
|
||||
id := s.StreamID()
|
||||
if _, ok := m.streams[id]; ok {
|
||||
return fmt.Errorf("a stream with ID %d already exists", id)
|
||||
}
|
||||
|
||||
m.streams[id] = s
|
||||
m.openStreams = append(m.openStreams, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -295,8 +250,8 @@ func (m *streamsMap) CloseWithError(err error) {
|
||||
m.closeErr = err
|
||||
m.nextStreamOrErrCond.Broadcast()
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
for _, s := range m.openStreams {
|
||||
m.streams[s].closeForShutdown(err)
|
||||
for _, s := range m.streams {
|
||||
s.closeForShutdown(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,22 +7,16 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Streams Map", func() {
|
||||
var (
|
||||
m *streamsMap
|
||||
finishedStreams map[protocol.StreamID]*gomock.Call
|
||||
)
|
||||
var m *streamsMap
|
||||
|
||||
newStream := func(id protocol.StreamID) streamI {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(id).AnyTimes()
|
||||
c := str.EXPECT().finished().Return(false).AnyTimes()
|
||||
finishedStreams[id] = c
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() {
|
||||
m = newStreamsMap(newStream, p, v)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(m.openStreams).To(HaveLen(len(m.streams)))
|
||||
})
|
||||
|
||||
deleteStream := func(id protocol.StreamID) {
|
||||
str := m.streams[id]
|
||||
Expect(str).ToNot(BeNil())
|
||||
finishedStreams[id].Return(true)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
|
||||
}
|
||||
|
||||
Context("getting and creating streams", func() {
|
||||
@@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("DoS mitigation, iterating and deleting", func() {
|
||||
Context("Ranging", func() {
|
||||
It("ranges over all open streams", func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
var callbackCalledForStream []protocol.StreamID
|
||||
callback := func(str streamI) {
|
||||
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
|
||||
sort.Slice(callbackCalledForStream, func(i, j int) bool {
|
||||
return callbackCalledForStream[i] < callbackCalledForStream[j]
|
||||
})
|
||||
}
|
||||
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
// create 5 streams, ids 4 to 8
|
||||
callbackCalledForStream = callbackCalledForStream[:0]
|
||||
for i := 4; i <= 8; i++ {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
|
||||
err := m.putStream(str)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
// execute the callback for all streams
|
||||
m.Range(callback)
|
||||
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("deleting streams", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
|
||||
})
|
||||
|
||||
closeStream := func(id protocol.StreamID) {
|
||||
str := m.streams[id]
|
||||
ExpectWithOffset(1, str).ToNot(BeNil())
|
||||
finishedStreams[id].Return(true)
|
||||
}
|
||||
|
||||
Context("deleting streams", func() {
|
||||
Context("as a server", func() {
|
||||
BeforeEach(func() {
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
for i := 1; i <= 5; i++ {
|
||||
if i%2 == 1 {
|
||||
_, err := m.openRemoteStream(protocol.StreamID(i))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
|
||||
})
|
||||
|
||||
It("does not delete streams with Close()", func() {
|
||||
str, err := m.GetOrOpenStream(55)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.(*MockStreamI).EXPECT().Close()
|
||||
str.Close()
|
||||
err = m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err = m.GetOrOpenStream(55)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("removes the first stream", func() {
|
||||
closeStream(1)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a stream in the middle", func() {
|
||||
closeStream(3)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a client-initiated stream", func() {
|
||||
closeStream(2)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
|
||||
})
|
||||
|
||||
It("removes a stream at the end", func() {
|
||||
closeStream(5)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes all streams", func() {
|
||||
for i := 1; i <= 5; i++ {
|
||||
closeStream(protocol.StreamID(i))
|
||||
}
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
Expect(m.openStreams).To(BeEmpty())
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("as a client", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
for i := 1; i <= 5; i++ {
|
||||
if i%2 == 0 {
|
||||
_, err := m.openRemoteStream(protocol.StreamID(i))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
|
||||
})
|
||||
|
||||
It("removes a stream that we initiated", func() {
|
||||
closeStream(3)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("removes a stream that the server initiated", func() {
|
||||
closeStream(2)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.openStreams).To(HaveLen(4))
|
||||
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
It("removes all streams", func() {
|
||||
closeStream(3)
|
||||
closeStream(2)
|
||||
closeStream(5)
|
||||
closeStream(4)
|
||||
closeStream(7)
|
||||
err := m.DeleteClosedStreams()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(BeEmpty())
|
||||
Expect(m.openStreams).To(BeEmpty())
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(BeZero())
|
||||
})
|
||||
})
|
||||
It("deletes an incoming stream", func() {
|
||||
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
|
||||
err = m.DeleteStream(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.streams).To(HaveLen(1))
|
||||
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
Context("Ranging", func() {
|
||||
// create 5 streams, ids 4 to 8
|
||||
var callbackCalledForStream []protocol.StreamID
|
||||
callback := func(str streamI) {
|
||||
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
|
||||
sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] })
|
||||
}
|
||||
It("deletes an outgoing stream", func() {
|
||||
m.UpdateMaxStreamLimit(10000)
|
||||
_, err := m.OpenStream() // open stream 2
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
|
||||
err = m.DeleteStream(2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
callbackCalledForStream = callbackCalledForStream[:0]
|
||||
for i := 4; i <= 8; i++ {
|
||||
str := NewMockStreamI(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
|
||||
err := m.putStream(str)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
It("ranges over all open streams", func() {
|
||||
m.Range(callback)
|
||||
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
|
||||
})
|
||||
It("errors when the stream doesn't exist", func() {
|
||||
err := m.DeleteStream(1337)
|
||||
Expect(err).To(MatchError(errMapAccess))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user