Merge pull request #1081 from lucas-clemente/stream-completed-callback

immediately delete a stream when it is completed
This commit is contained in:
Marten Seemann
2018-01-03 11:33:19 +07:00
committed by GitHub
12 changed files with 199 additions and 423 deletions

View File

@@ -169,18 +169,6 @@ func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0)
}
// finished mocks base method
func (m *MockStreamI) finished() bool {
ret := m.ctrl.Call(m, "finished")
ret0, _ := ret[0].(bool)
return ret0
}
// finished indicates an expected call of finished
func (mr *MockStreamIMockRecorder) finished() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "finished", reflect.TypeOf((*MockStreamI)(nil).finished))
}
// getWindowUpdate mocks base method
func (m *MockStreamI) getWindowUpdate() protocol.ByteCount {
ret := m.ctrl.Call(m, "getWindowUpdate")

View File

@@ -55,6 +55,16 @@ func (mr *MockStreamSenderMockRecorder) onHasWindowUpdate(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasWindowUpdate", reflect.TypeOf((*MockStreamSender)(nil).onHasWindowUpdate), arg0)
}
// onStreamCompleted mocks base method
func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) {
m.ctrl.Call(m, "onStreamCompleted", arg0)
}
// onStreamCompleted indicates an expected call of onStreamCompleted
func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0)
}
// queueControlFrame mocks base method
func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) {
m.ctrl.Call(m, "queueControlFrame", arg0)

View File

@@ -148,6 +148,7 @@ func (s *receiveStream) Read(p []byte) (int, error) {
s.frameQueue.Pop()
s.finRead = frame.FinBit
if frame.FinBit {
s.sender.onStreamCompleted(s.streamID)
return bytesRead, io.EOF
}
}
@@ -219,6 +220,7 @@ func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode),
}
s.signalRead()
s.sender.onStreamCompleted(s.streamID)
return nil
}
@@ -259,14 +261,6 @@ func (s *receiveStream) closeForShutdown(err error) {
s.signalRead()
}
func (s *receiveStream) finished() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
s.finRead || s.resetRemotely
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}

View File

@@ -33,7 +33,6 @@ var _ = Describe("Receive Stream", func() {
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutReader(str, timeout)
strWithTimeout = str
})
It("gets stream id", func() {
@@ -320,12 +319,12 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(4), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD, 0xBE, 0xEF},
FinBit: true,
}
str.handleStreamFrame(&frame)
})
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
@@ -354,6 +353,7 @@ var _ = Describe("Receive Stream", func() {
Expect(err).ToNot(HaveOccurred())
err = str.handleStreamFrame(&frame2)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
@@ -368,31 +368,30 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(2), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{0xDE, 0xAD},
Data: []byte{0xde, 0xad},
FinBit: true,
}
err := str.handleStreamFrame(&frame)
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte{0xDE, 0xAD}))
Expect(b[:n]).To(Equal([]byte{0xde, 0xad}))
})
It("handles immediate FINs", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
frame := wire.StreamFrame{
err := str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte{},
FinBit: true,
}
err := str.handleStreamFrame(&frame)
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 4)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
@@ -405,6 +404,7 @@ var _ = Describe("Receive Stream", func() {
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
str.CloseRemote(0)
mockSender.EXPECT().onStreamCompleted(streamID)
b := make([]byte, 8)
n, err := strWithTimeout.Read(b)
Expect(n).To(BeZero())
@@ -486,6 +486,7 @@ var _ = Describe("Receive Stream", func() {
FinBit: true,
})
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err = strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
err = str.CancelRead(1234)
@@ -526,11 +527,13 @@ var _ = Describe("Receive Stream", func() {
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleRstStreamFrame(rst)
Eventually(done).Should(BeClosed())
})
It("doesn't allow further calls to Read", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
@@ -549,6 +552,7 @@ var _ = Describe("Receive Stream", func() {
})
It("ignores duplicate RST_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
@@ -580,6 +584,7 @@ var _ = Describe("Receive Stream", func() {
close(readReturned)
}()
Consistently(readReturned).ShouldNot(BeClosed())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.handleRstStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
Eventually(readReturned).Should(BeClosed())
@@ -587,8 +592,11 @@ var _ = Describe("Receive Stream", func() {
It("continues reading until the end when receiving a RST_STREAM frame with error code 0", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true).Times(2)
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4))
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
gomock.InOrder(
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(4)),
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
mockSender.EXPECT().onStreamCompleted(streamID),
)
mockFC.EXPECT().HasWindowUpdate().Times(2)
readReturned := make(chan struct{})
go func() {
@@ -637,43 +645,4 @@ var _ = Describe("Receive Stream", func() {
Expect(str.getWindowUpdate()).To(Equal(protocol.ByteCount(0x100)))
})
})
Context("saying if it is finished", func() {
finishReading := func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true)
err := str.handleStreamFrame(&wire.StreamFrame{FinBit: true})
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 100)
_, err = strWithTimeout.Read(b)
ExpectWithOffset(0, err).To(MatchError(io.EOF))
}
It("is finished after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("testErr"))
Expect(str.finished()).To(BeTrue())
})
It("is finished if it is only closed for reading", func() {
mockFC.EXPECT().AddBytesRead(protocol.ByteCount(0))
mockFC.EXPECT().HasWindowUpdate()
finishReading()
Expect(str.finished()).To(BeTrue())
})
// the stream still needs to stay alive until we receive the final offset
// (either by receiving a STREAM frame with FIN, or a RST_STREAM)
It("is not finished after CancelRead", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelRead(123)
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeFalse())
})
It("is finished after receiving a RST_STREAM frame", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(3), true)
err := str.handleRstStreamFrame(&wire.RstStreamFrame{ByteOffset: 3})
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeTrue())
})
})
})

View File

@@ -154,6 +154,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFr
}
if frame.FinBit {
s.finSent = true
s.sender.onStreamCompleted(s.streamID)
} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamBlockedFrame{
@@ -231,6 +232,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr
})
// TODO(#991): cancel retransmissions for this stream
s.ctxCancel()
s.sender.onStreamCompleted(s.streamID)
return nil
}
@@ -289,14 +291,6 @@ func (s *sendStream) closeForShutdown(err error) {
s.ctxCancel()
}
func (s *sendStream) finished() bool {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.closedForShutdown || // if the stream was abruptly closed for shutting down
s.finSent || s.canceledWrite
}
func (s *sendStream) getWriteOffset() protocol.ByteCount {
return s.writeOffset
}

View File

@@ -247,6 +247,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't queue a BLOCKED frame if the stream is flow control blocked, but the frame popped has the FIN bit set", func() {
mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for the Write, once for the Close
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999))
mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6))
// don't EXPECT a call to mockFC.IsBlocked
@@ -389,6 +390,7 @@ var _ = Describe("Send Stream", func() {
It("allows FIN", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, hasMoreData := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
@@ -409,6 +411,7 @@ var _ = Describe("Send Stream", func() {
Expect(f).ToNot(BeNil())
Expect(f.Data).To(Equal([]byte("foo")))
Expect(f.FinBit).To(BeFalse())
mockSender.EXPECT().onStreamCompleted(streamID)
f, _ = str.popStreamFrame(100)
Expect(f.Data).To(Equal([]byte("bar")))
Expect(f.FinBit).To(BeTrue())
@@ -423,6 +426,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow FIN twice", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f).ToNot(BeNil())
@@ -513,6 +517,7 @@ var _ = Describe("Send Stream", func() {
ByteOffset: 1234,
ErrorCode: 9876,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.writeOffset = 1234
err := str.CancelWrite(9876)
Expect(err).ToNot(HaveOccurred())
@@ -520,6 +525,7 @@ var _ = Describe("Send Stream", func() {
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().onStreamCompleted(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount)
mockFC.EXPECT().AddBytesSent(gomock.Any())
@@ -544,6 +550,7 @@ var _ = Describe("Send Stream", func() {
It("cancels the context", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.Context().Done()).ToNot(BeClosed())
str.CancelWrite(1234)
Expect(str.Context().Done()).To(BeClosed())
@@ -551,6 +558,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
_, err = strWithTimeout.Write([]byte("foobar"))
@@ -559,6 +567,7 @@ var _ = Describe("Send Stream", func() {
It("only cancels once", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
err := str.CancelWrite(1234)
Expect(err).ToNot(HaveOccurred())
err = str.CancelWrite(4321)
@@ -580,6 +589,7 @@ var _ = Describe("Send Stream", func() {
StreamID: streamID,
ErrorCode: errorCodeStopping,
})
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 101,
@@ -600,6 +610,7 @@ var _ = Describe("Send Stream", func() {
close(done)
}()
waitForWrite()
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
@@ -609,6 +620,7 @@ var _ = Describe("Send Stream", func() {
It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
@@ -621,32 +633,4 @@ var _ = Describe("Send Stream", func() {
})
})
})
Context("saying if it is finished", func() {
It("is finished after it is closed for shutdown", func() {
str.closeForShutdown(errors.New("testErr"))
Expect(str.finished()).To(BeTrue())
})
It("is finished after Close()", func() {
mockSender.EXPECT().onHasStreamData(streamID)
str.Close()
f, _ := str.popStreamFrame(1000)
Expect(f.FinBit).To(BeTrue())
Expect(str.finished()).To(BeTrue())
})
It("is finished after CancelWrite", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
err := str.CancelWrite(123)
Expect(err).ToNot(HaveOccurred())
Expect(str.finished()).To(BeTrue())
})
It("is finished after receiving a STOP_SENDING (and sending a RST_STREAM)", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID})
Expect(str.finished()).To(BeTrue())
})
})
})

View File

@@ -408,10 +408,6 @@ runLoop:
if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout {
s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if err := s.streamsMap.DeleteClosedStreams(); err != nil {
s.closeLocal(err)
}
}
// only send the error the handshakeChan when the handshake is not completed yet
@@ -950,6 +946,12 @@ func (s *session) onHasStreamData(id protocol.StreamID) {
s.scheduleSending()
}
func (s *session) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.Close(err)
}
}
func (s *session) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}

View File

@@ -122,7 +122,7 @@ var _ = Describe("Session", func() {
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(BeEmpty())
Expect(sess.streamsMap.streams).To(BeEmpty())
})
AfterEach(func() {
@@ -195,9 +195,6 @@ var _ = Describe("Session", func() {
sess.streamsMap.newStream = func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
if id == 1 {
str.EXPECT().finished().AnyTimes()
}
return str
}
})
@@ -247,9 +244,9 @@ var _ = Describe("Session", func() {
return str
}
sess.handleStreamFrame(f1)
numOpenStreams := len(sess.streamsMap.openStreams)
numOpenStreams := len(sess.streamsMap.streams)
sess.handleStreamFrame(f2)
Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams))
Expect(sess.streamsMap.streams).To(HaveLen(numOpenStreams))
})
It("ignores STREAM frames for closed streams", func() {
@@ -329,8 +326,7 @@ var _ = Describe("Session", func() {
It("ignores the error when the stream is not known", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
sess.streamsMap.DeleteClosedStreams()
sess.onStreamCompleted(3)
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@@ -411,9 +407,7 @@ var _ = Describe("Session", func() {
It("ignores MAX_STREAM_DATA frames for a closed stream", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
sess.onStreamCompleted(3)
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@@ -457,8 +451,7 @@ var _ = Describe("Session", func() {
It("ignores STOP_SENDING frames for a closed stream", func() {
str, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().finished().Return(true)
err = sess.streamsMap.DeleteClosedStreams()
sess.onStreamCompleted(3)
Expect(err).ToNot(HaveOccurred())
str, err = sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
@@ -1403,12 +1396,7 @@ var _ = Describe("Session", func() {
It("returns a nil-value (not an interface with value nil) for closed streams", func() {
str, err := sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
str.Close()
str.(*stream).closeForShutdown(nil)
Expect(str.(*stream).finished()).To(BeTrue())
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil())
sess.onStreamCompleted(9)
str, err = sess.GetOrOpenStream(9)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
@@ -1425,31 +1413,6 @@ var _ = Describe("Session", func() {
})
})
Context("counting streams", func() {
It("errors when too many streams are opened", func() {
for i := 0; i < protocol.MaxIncomingStreams; i++ {
_, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
}
_, err := sess.GetOrOpenStream(protocol.StreamID(301))
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("does not error when many streams are opened and closed", func() {
for i := 2; i <= 1000; i++ {
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred())
Expect(s.Close()).To(Succeed())
f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit
Expect(f.FinBit).To(BeTrue())
s.(*stream).CloseRemote(0)
_, err = s.Read([]byte("a"))
Expect(err).To(MatchError(io.EOF))
sess.streamsMap.DeleteClosedStreams()
}
})
})
Context("ignoring errors", func() {
It("ignores duplicate acks", func() {
sess.sentPacketHandler.SentPacket(&ackhandler.Packet{
@@ -1522,7 +1485,7 @@ var _ = Describe("Client Session", func() {
)
sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.openStreams).To(BeEmpty())
Expect(sess.streamsMap.streams).To(BeEmpty())
})
AfterEach(func() {

View File

@@ -2,6 +2,7 @@ package quic
import (
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
@@ -19,8 +20,34 @@ type streamSender interface {
queueControlFrame(wire.Frame)
onHasWindowUpdate(protocol.StreamID)
onHasStreamData(protocol.StreamID)
onStreamCompleted(protocol.StreamID)
}
// Each of the both stream halves gets its own uniStreamSender.
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
}
func (s *uniStreamSender) onHasWindowUpdate(id protocol.StreamID) {
s.streamSender.onHasWindowUpdate(id)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
}
var _ streamSender = &uniStreamSender{}
type streamI interface {
Stream
@@ -28,7 +55,6 @@ type streamI interface {
handleRstStreamFrame(*wire.RstStreamFrame) error
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool)
finished() bool
closeForShutdown(error)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
@@ -42,6 +68,11 @@ type stream struct {
receiveStream
sendStream
completedMutex sync.Mutex
sender streamSender
receiveStreamCompleted bool
sendStreamCompleted bool
version protocol.VersionNumber
}
@@ -72,10 +103,28 @@ func newStream(streamID protocol.StreamID,
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *stream {
return &stream{
sendStream: *newSendStream(streamID, sender, flowController, version),
receiveStream: *newReceiveStream(streamID, sender, flowController),
s := &stream{sender: sender}
senderForSendStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.sendStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.receiveStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
}
// need to define StreamID() here, since both receiveStream and readStream have a StreamID()
@@ -120,6 +169,10 @@ func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error {
return nil
}
func (s *stream) finished() bool {
return s.sendStream.finished() && s.receiveStream.finished()
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
func (s *stream) checkIfCompleted() {
if s.sendStreamCompleted && s.receiveStreamCompleted {
s.sender.onStreamCompleted(s.StreamID())
}
}

View File

@@ -1,7 +1,6 @@
package quic
import (
"errors"
"io"
"os"
"strconv"
@@ -72,6 +71,7 @@ var _ = Describe("Stream", func() {
ByteOffset: 1000,
ErrorCode: errorCodeStoppingGQUIC,
})
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
str.writeOffset = 1000
f := &wire.RstStreamFrame{
@@ -189,26 +189,21 @@ var _ = Describe("Stream", func() {
})
})
Context("saying if it is finished", func() {
It("is finished when both the send and the receive side are finished", func() {
str.receiveStream.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeTrue())
Expect(str.sendStream.finished()).To(BeFalse())
Expect(str.finished()).To(BeFalse())
Context("completing", func() {
It("is not completed when only the receive side is completed", func() {
// don't EXPECT a call to mockSender.onStreamCompleted()
str.receiveStream.sender.onStreamCompleted(streamID)
})
It("is not finished when the receive side is finished", func() {
str.sendStream.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeFalse())
Expect(str.sendStream.finished()).To(BeTrue())
Expect(str.finished()).To(BeFalse())
It("is not completed when only the send side is completed", func() {
// don't EXPECT a call to mockSender.onStreamCompleted()
str.sendStream.sender.onStreamCompleted(streamID)
})
It("is not finished when the send side is finished", func() {
str.closeForShutdown(errors.New("shutdown"))
Expect(str.receiveStream.finished()).To(BeTrue())
Expect(str.sendStream.finished()).To(BeTrue())
Expect(str.finished()).To(BeTrue())
It("is completed when both sides are completed", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
str.sendStream.sender.onStreamCompleted(streamID)
str.receiveStream.sender.onStreamCompleted(streamID)
})
})
})

View File

@@ -16,9 +16,6 @@ type streamsMap struct {
perspective protocol.Perspective
streams map[protocol.StreamID]streamI
// needed for round-robin scheduling
openStreams []protocol.StreamID
roundRobinIndex int
nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
@@ -51,7 +48,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
sm := streamsMap{
perspective: pers,
streams: make(map[protocol.StreamID]streamI),
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
maxIncomingStreams: maxIncomingStreams,
}
@@ -99,7 +95,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
s, ok := m.streams[id]
m.mutex.RUnlock()
if ok {
return s, nil // s may be nil
return s, nil
}
// ... we don't have an existing stream
@@ -212,48 +208,19 @@ func (m *streamsMap) AcceptStream() (streamI, error) {
return str, nil
}
func (m *streamsMap) DeleteClosedStreams() error {
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
var numDeletedStreams int
// for every closed stream, the streamID is replaced by 0 in the openStreams slice
for i, streamID := range m.openStreams {
str, ok := m.streams[streamID]
if !ok {
return errMapAccess
}
if !str.finished() {
continue
}
numDeletedStreams++
m.openStreams[i] = 0
if m.streamInitiatedBy(streamID) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
delete(m.streams, streamID)
_, ok := m.streams[id]
if !ok {
return errMapAccess
}
if numDeletedStreams == 0 {
return nil
delete(m.streams, id)
if m.streamInitiatedBy(id) == m.perspective {
m.numOutgoingStreams--
} else {
m.numIncomingStreams--
}
// remove all 0s (representing closed streams) from the openStreams slice
// and adjust the roundRobinIndex
var j int
for i, id := range m.openStreams {
if i != j {
m.openStreams[j] = m.openStreams[i]
}
if id != 0 {
j++
} else if j < m.roundRobinIndex {
m.roundRobinIndex--
}
}
m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams]
m.openStreamOrErrCond.Signal()
return nil
}
@@ -264,28 +231,16 @@ func (m *streamsMap) Range(cb func(s streamI)) {
defer m.mutex.RUnlock()
for _, s := range m.streams {
if s != nil {
cb(s)
}
cb(s)
}
}
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
str, ok := m.streams[streamID]
if !ok {
return true, errMapAccess
}
return fn(str)
}
func (m *streamsMap) putStream(s streamI) error {
id := s.StreamID()
if _, ok := m.streams[id]; ok {
return fmt.Errorf("a stream with ID %d already exists", id)
}
m.streams[id] = s
m.openStreams = append(m.openStreams, id)
return nil
}
@@ -295,8 +250,8 @@ func (m *streamsMap) CloseWithError(err error) {
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast()
for _, s := range m.openStreams {
m.streams[s].closeForShutdown(err)
for _, s := range m.streams {
s.closeForShutdown(err)
}
}

View File

@@ -7,22 +7,16 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Streams Map", func() {
var (
m *streamsMap
finishedStreams map[protocol.StreamID]*gomock.Call
)
var m *streamsMap
newStream := func(id protocol.StreamID) streamI {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(id).AnyTimes()
c := str.EXPECT().finished().Return(false).AnyTimes()
finishedStreams[id] = c
return str
}
@@ -30,20 +24,8 @@ var _ = Describe("Streams Map", func() {
m = newStreamsMap(newStream, p, v)
}
BeforeEach(func() {
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
})
AfterEach(func() {
Expect(m.openStreams).To(HaveLen(len(m.streams)))
})
deleteStream := func(id protocol.StreamID) {
str := m.streams[id]
Expect(str).ToNot(BeNil())
finishedStreams[id].Return(true)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed())
}
Context("getting and creating streams", func() {
@@ -521,176 +503,63 @@ var _ = Describe("Streams Map", func() {
})
})
Context("DoS mitigation, iterating and deleting", func() {
Context("Ranging", func() {
It("ranges over all open streams", func() {
setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever)
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool {
return callbackCalledForStream[i] < callbackCalledForStream[j]
})
}
Expect(m.streams).To(BeEmpty())
// create 5 streams, ids 4 to 8
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
// execute the callback for all streams
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
})
Context("deleting streams", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames)
})
closeStream := func(id protocol.StreamID) {
str := m.streams[id]
ExpectWithOffset(1, str).ToNot(BeNil())
finishedStreams[id].Return(true)
}
Context("deleting streams", func() {
Context("as a server", func() {
BeforeEach(func() {
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 1 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4
Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5
})
It("does not delete streams with Close()", func() {
str, err := m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
str.(*MockStreamI).EXPECT().Close()
str.Close()
err = m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
str, err = m.GetOrOpenStream(55)
Expect(err).ToNot(HaveOccurred())
Expect(str).ToNot(BeNil())
})
It("removes the first stream", func() {
closeStream(1)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream in the middle", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a client-initiated stream", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(3))
})
It("removes a stream at the end", func() {
closeStream(5)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes all streams", func() {
for i := 1; i <= 5; i++ {
closeStream(protocol.StreamID(i))
}
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames)
m.UpdateMaxStreamLimit(100)
for i := 1; i <= 5; i++ {
if i%2 == 0 {
_, err := m.openRemoteStream(protocol.StreamID(i))
Expect(err).ToNot(HaveOccurred())
} else {
_, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
}
}
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7
Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4
})
It("removes a stream that we initiated", func() {
closeStream(3)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
})
It("removes a stream that the server initiated", func() {
closeStream(2)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.openStreams).To(HaveLen(4))
Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7}))
Expect(m.numOutgoingStreams).To(BeEquivalentTo(3))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("removes all streams", func() {
closeStream(3)
closeStream(2)
closeStream(5)
closeStream(4)
closeStream(7)
err := m.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(BeEmpty())
Expect(m.openStreams).To(BeEmpty())
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeZero())
})
})
It("deletes an incoming stream", func() {
_, err := m.GetOrOpenStream(5) // open stream 3 and 5
Expect(err).ToNot(HaveOccurred())
Expect(m.numIncomingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(3)
Expect(err).ToNot(HaveOccurred())
Expect(m.streams).To(HaveLen(1))
Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
Context("Ranging", func() {
// create 5 streams, ids 4 to 8
var callbackCalledForStream []protocol.StreamID
callback := func(str streamI) {
callbackCalledForStream = append(callbackCalledForStream, str.StreamID())
sort.Slice(callbackCalledForStream, func(i, j int) bool { return callbackCalledForStream[i] < callbackCalledForStream[j] })
}
It("deletes an outgoing stream", func() {
m.UpdateMaxStreamLimit(10000)
_, err := m.OpenStream() // open stream 2
Expect(err).ToNot(HaveOccurred())
_, err = m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(2))
err = m.DeleteStream(2)
Expect(err).ToNot(HaveOccurred())
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
BeforeEach(func() {
callbackCalledForStream = callbackCalledForStream[:0]
for i := 4; i <= 8; i++ {
str := NewMockStreamI(mockCtrl)
str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes()
err := m.putStream(str)
Expect(err).NotTo(HaveOccurred())
}
})
It("ranges over all open streams", func() {
m.Range(callback)
Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8}))
})
It("errors when the stream doesn't exist", func() {
err := m.DeleteStream(1337)
Expect(err).To(MatchError(errMapAccess))
})
})
})