From a20e94ee165c36790d3aa6919794975a4ba6e0fe Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 4 Jan 2018 10:26:02 +0700 Subject: [PATCH] use separate streamsMaps for gQUIC and IETF QUIC This is a lot of duplicate code for now, but it will make moving towards the new stream ID mapping in IETF QUIC (and unidirectional streams) much easier. --- session.go | 18 +- streams_map.go | 22 +- streams_map_legacy.go | 257 +++++++++++++++++ streams_map_legacy_test.go | 549 +++++++++++++++++++++++++++++++++++++ streams_map_test.go | 116 ++++---- 5 files changed, 873 insertions(+), 89 deletions(-) create mode 100644 streams_map_legacy.go create mode 100644 streams_map_legacy_test.go diff --git a/session.go b/session.go index 48200ad3..1ed80132 100644 --- a/session.go +++ b/session.go @@ -28,6 +28,18 @@ type streamGetter interface { GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) } +type streamManager interface { + GetOrOpenStream(protocol.StreamID) (streamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenStreamSync() (Stream, error) + AcceptStream() (Stream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*handshake.TransportParameters) + CloseWithError(error) +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -310,7 +322,11 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) + if s.version.UsesTLS() { + s.streamsMap = newStreamsMap(s.newStream, s.perspective) + } else { + s.streamsMap = newStreamsMapLegacy(s.newStream, s.perspective) + } s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker(s.connectionID, initialPacketNumber, diff --git a/streams_map.go b/streams_map.go index 496f5dd3..86d945fc 100644 --- a/streams_map.go +++ b/streams_map.go @@ -12,18 +12,6 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) -type streamManager interface { - GetOrOpenStream(protocol.StreamID) (streamI, error) - GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) - GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) - OpenStream() (Stream, error) - OpenStreamSync() (Stream, error) - AcceptStream() (Stream, error) - DeleteStream(protocol.StreamID) error - UpdateLimits(*handshake.TransportParameters) - CloseWithError(error) -} - type streamsMap struct { mutex sync.RWMutex @@ -49,12 +37,11 @@ type streamsMap struct { var _ streamManager = &streamsMap{} -type streamLambda func(streamI) (bool, error) type newStreamLambda func(protocol.StreamID) streamI var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) streamManager { // add some tolerance to the maximum incoming streams value maxStreams := uint32(protocol.MaxIncomingStreams) maxIncomingStreams := utils.MaxUint32( @@ -72,13 +59,6 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro nextClientInitiatedStream := protocol.StreamID(1) nextServerInitiatedStream := protocol.StreamID(2) - if !ver.UsesTLS() { - nextServerInitiatedStream = 2 - nextClientInitiatedStream = 3 - if pers == protocol.PerspectiveServer { - sm.highestStreamOpenedByPeer = 1 - } - } if pers == protocol.PerspectiveServer { sm.nextStreamToOpen = nextServerInitiatedStream sm.nextStreamToAccept = nextClientInitiatedStream diff --git a/streams_map_legacy.go b/streams_map_legacy.go new file mode 100644 index 00000000..de52a8f6 --- /dev/null +++ b/streams_map_legacy.go @@ -0,0 +1,257 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamsMapLegacy struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + streams map[protocol.StreamID]streamI + + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + highestStreamOpenedByPeer protocol.StreamID + nextStreamOrErrCond sync.Cond + openStreamOrErrCond sync.Cond + + closeErr error + nextStreamToAccept protocol.StreamID + + newStream newStreamLambda + + numOutgoingStreams uint32 + numIncomingStreams uint32 + maxIncomingStreams uint32 + maxOutgoingStreams uint32 +} + +var _ streamManager = &streamsMapLegacy{} + +func newStreamsMapLegacy(newStream newStreamLambda, pers protocol.Perspective) streamManager { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingStreams) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) + sm := streamsMapLegacy{ + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, + } + sm.nextStreamOrErrCond.L = &sm.mutex + sm.openStreamOrErrCond.L = &sm.mutex + + nextServerInitiatedStream := protocol.StreamID(2) + nextClientInitiatedStream := protocol.StreamID(3) + if pers == protocol.PerspectiveServer { + sm.highestStreamOpenedByPeer = 1 + } + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream + } else { + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream + } + return &sm +} + +// getStreamPerspective says which side should initiate a stream +func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective { + if id%2 == 0 { + return protocol.PerspectiveServer + } + return protocol.PerspectiveClient +} + +func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + // every bidirectional stream is also a receive stream + return m.GetOrOpenStream(id) +} + +func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + // every bidirectional stream is also a send stream + return m.GetOrOpenStream(id) +} + +// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +func (m *streamsMapLegacy) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if ok { + return s, nil + } + + // ... we don't have an existing stream + m.mutex.Lock() + defer m.mutex.Unlock() + // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). + s, ok = m.streams[id] + if ok { + return s, nil + } + + if m.perspective == m.streamInitiatedBy(id) { + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already + return nil, nil + } + + for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 { + if _, err := m.openRemoteStream(sid); err != nil { + return nil, err + } + } + + m.nextStreamOrErrCond.Broadcast() + return m.streams[id], nil +} + +func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) { + if m.numIncomingStreams >= m.maxIncomingStreams { + return nil, qerr.TooManyOpenStreams + } + if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) + } + + m.numIncomingStreams++ + if id > m.highestStreamOpenedByPeer { + m.highestStreamOpenedByPeer = id + } + + s := m.newStream(id) + m.putStream(s) + return s, nil +} + +func (m *streamsMapLegacy) openStreamImpl() (streamI, error) { + if m.numOutgoingStreams >= m.maxOutgoingStreams { + return nil, qerr.TooManyOpenStreams + } + + m.numOutgoingStreams++ + s := m.newStream(m.nextStreamToOpen) + m.putStream(s) + m.nextStreamToOpen += 2 + return s, nil +} + +// OpenStream opens the next available stream +func (m *streamsMapLegacy) OpenStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + return m.openStreamImpl() +} + +func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.openStreamOrErrCond.Wait() + } +} + +// AcceptStream returns the next stream opened by the peer +// it blocks until a new stream is opened +func (m *streamsMapLegacy) AcceptStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStreamToAccept] + if ok { + break + } + m.nextStreamOrErrCond.Wait() + } + m.nextStreamToAccept += 2 + return str, nil +} + +func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if !ok { + return errMapAccess + } + delete(m.streams, id) + if m.streamInitiatedBy(id) == m.perspective { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } + m.openStreamOrErrCond.Signal() + return nil +} + +func (m *streamsMapLegacy) putStream(s streamI) error { + id := s.StreamID() + if _, ok := m.streams[id]; ok { + return fmt.Errorf("a stream with ID %d already exists", id) + } + m.streams[id] = s + return nil +} + +func (m *streamsMapLegacy) CloseWithError(err error) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.closeErr = err + m.nextStreamOrErrCond.Broadcast() + m.openStreamOrErrCond.Broadcast() + for _, s := range m.streams { + s.closeForShutdown(err) + } +} + +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) { + m.mutex.Lock() + m.maxOutgoingStreams = params.MaxStreams + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() + m.openStreamOrErrCond.Broadcast() +} diff --git a/streams_map_legacy_test.go b/streams_map_legacy_test.go new file mode 100644 index 00000000..7e9324f8 --- /dev/null +++ b/streams_map_legacy_test.go @@ -0,0 +1,549 @@ +package quic + +import ( + "errors" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Streams Map (for gQUIC)", func() { + var m *streamsMapLegacy + + newStream := func(id protocol.StreamID) streamI { + str := NewMockStreamI(mockCtrl) + str.EXPECT().StreamID().Return(id).AnyTimes() + return str + } + + setNewStreamsMap := func(p protocol.Perspective) { + m = newStreamsMapLegacy(newStream, p).(*streamsMapLegacy) + } + + deleteStream := func(id protocol.StreamID) { + ExpectWithOffset(1, m.DeleteStream(id)).To(Succeed()) + } + + Context("getting and creating streams", func() { + Context("as a server", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveServer) + }) + + Context("client-side streams", func() { + It("gets new streams", func() { + s, err := m.GetOrOpenStream(3) + Expect(err).NotTo(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(m.streams).To(HaveLen(1)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.numOutgoingStreams).To(BeZero()) + }) + + It("rejects streams with even IDs", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 6")) + }) + + It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(4) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 4")) + }) + + It("gets existing streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + numStreams := m.numIncomingStreams + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(Equal(numStreams)) + }) + + It("returns nil for closed streams", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + deleteStream(5) + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(BeNil()) + }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(7) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(3))) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.streams).To(HaveKey(protocol.StreamID(7))) + }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + deleteStream(5) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + Context("counting streams", func() { + It("errors when too many streams are opened", func() { + for i := uint32(0); i < m.maxIncomingStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.GetOrOpenStream(protocol.StreamID(2*m.maxIncomingStreams + 3)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("errors when too many streams are opened implicitely", func() { + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 3)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := uint32(2); i < 10*m.maxIncomingStreams; i++ { + str, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + deleteStream(str.StreamID()) + } + }) + }) + }) + + Context("server-side streams", func() { + It("doesn't allow opening streams before receiving the transport parameters", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("opens a stream 2 first", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + s, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) + Expect(m.numIncomingStreams).To(BeZero()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + }) + + It("returns the error when the streamsMap was closed", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStream() + Expect(err).To(MatchError(testErr)) + }) + + It("doesn't reopen an already closed stream", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + deleteStream(2) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + Context("counting streams", func() { + const maxOutgoingStreams = 50 + + BeforeEach(func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) + }) + + It("errors when too many streams are opened", func() { + for i := 1; i <= maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i < 10*maxOutgoingStreams; i++ { + str, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + deleteStream(str.StreamID()) + } + }) + + It("allows many server- and client-side streams at the same time", func() { + for i := 1; i < maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + for i := 0; i < maxOutgoingStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1)) + Expect(err).ToNot(HaveOccurred()) + } + }) + }) + + Context("opening streams synchronously", func() { + const maxOutgoingStreams = 10 + + BeforeEach(func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) + }) + + openMaxNumStreams := func() { + for i := 1; i <= maxOutgoingStreams; i++ { + _, err := m.OpenStream() + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + } + + It("waits until another stream is closed", func() { + openMaxNumStreams() + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.OpenStreamSync() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + deleteStream(6) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxOutgoingStreams + 2))) + }) + + It("stops waiting when an error is registered", func() { + testErr := errors.New("test error") + openMaxNumStreams() + for _, str := range m.streams { + str.(*MockStreamI).EXPECT().closeForShutdown(testErr) + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("immediately returns when OpenStreamSync is called after an error was registered", func() { + testErr := errors.New("test error") + m.CloseWithError(testErr) + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("accepting streams", func() { + It("does nothing if no stream is opened", func() { + var accepted bool + go func() { + _, _ = m.AcceptStream() + accepted = true + }() + Consistently(func() bool { return accepted }).Should(BeFalse()) + }) + + It("starts with stream 3", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns an implicitly opened stream, if a stream number is skipped", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns to multiple accepts", func() { + var str1, str2 Stream + done1 := make(chan struct{}) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str1, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done1) + }() + go func() { + defer GinkgoRecover() + var err error + str2, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done2) + }() + _, err := m.GetOrOpenStream(5) // opens stream 3 and 5 + Expect(err).ToNot(HaveOccurred()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5)) + }) + + It("waits until a new stream is available", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("returns multiple streams on subsequent Accept calls, if available", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) + }) + + It("blocks after accepting a stream", func() { + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, _ = m.AcceptStream() + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any()) + m.CloseWithError(errors.New("shut down")) + Eventually(done).Should(BeClosed()) + }) + + It("stops waiting when an error is registered", func() { + testErr := errors.New("testErr") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) + It("immediately returns when Accept is called after an error was registered", func() { + testErr := errors.New("testErr") + m.CloseWithError(testErr) + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + }) + }) + }) + + Context("as a client", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveClient) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + }) + + Context("server-side streams", func() { + It("rejects streams with odd IDs", func() { + _, err := m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) + }) + + It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + _, err = m.GetOrOpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) + }) + + It("gets new streams", func() { + s, err := m.GetOrOpenStream(2) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) + Expect(m.streams).To(HaveLen(1)) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) + + It("opens skipped streams", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).NotTo(HaveOccurred()) + Expect(m.streams).To(HaveKey(protocol.StreamID(2))) + Expect(m.streams).To(HaveKey(protocol.StreamID(4))) + Expect(m.streams).To(HaveKey(protocol.StreamID(6))) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) + }) + + It("doesn't reopen an already closed stream", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + deleteStream(3) + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + }) + + Context("client-side streams", func() { + It("starts with stream 3", func() { + s, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(BeEquivalentTo(3)) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(BeZero()) + }) + + It("opens multiple streams", func() { + s1, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + s2, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2)) + }) + + It("doesn't reopen an already closed stream", func() { + _, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + deleteStream(4) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetOrOpenStream(4) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + }) + + Context("accepting streams", func() { + It("accepts stream 2 first", func() { + var str Stream + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + _, err := m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + }) + }) + }) + }) + + Context("deleting streams", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveServer) + }) + + It("deletes an incoming stream", func() { + _, err := m.GetOrOpenStream(5) // open stream 3 and 5 + Expect(err).ToNot(HaveOccurred()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(1)) + Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) + + It("deletes an outgoing stream", func() { + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) + _, err := m.OpenStream() // open stream 2 + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + err = m.DeleteStream(2) + Expect(err).ToNot(HaveOccurred()) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + }) + + It("errors when the stream doesn't exist", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError(errMapAccess)) + }) + }) + + It("sets the flow control limit", func() { + setNewStreamsMap(protocol.PerspectiveServer) + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 3, + ByteOffset: 321, + }) + m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 5, + ByteOffset: 321, + }) + m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321}) + }) +}) diff --git a/streams_map_test.go b/streams_map_test.go index be1c16f8..ab5b5bc0 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -3,6 +3,7 @@ package quic import ( "errors" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -12,7 +13,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Streams Map", func() { +var _ = Describe("Streams Map (for IETF QUIC)", func() { var m *streamsMap newStream := func(id protocol.StreamID) streamI { @@ -21,8 +22,8 @@ var _ = Describe("Streams Map", func() { return str } - setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) { - m = newStreamsMap(newStream, p, v) + setNewStreamsMap := func(p protocol.Perspective) { + m = newStreamsMap(newStream, p).(*streamsMap) } deleteStream := func(id protocol.StreamID) { @@ -32,15 +33,15 @@ var _ = Describe("Streams Map", func() { Context("getting and creating streams", func() { Context("as a server", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) + setNewStreamsMap(protocol.PerspectiveServer) }) Context("client-side streams", func() { It("gets new streams", func() { - s, err := m.GetOrOpenStream(3) + s, err := m.GetOrOpenStream(1) Expect(err).NotTo(HaveOccurred()) Expect(s).ToNot(BeNil()) - Expect(s.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) Expect(m.streams).To(HaveLen(1)) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numOutgoingStreams).To(BeZero()) @@ -264,8 +265,7 @@ var _ = Describe("Streams Map", func() { Consistently(func() bool { return accepted }).Should(BeFalse()) }) - It("starts with stream 1, if the crypto stream is stream 0", func() { - setNewStreamsMap(protocol.PerspectiveServer, versionIETFFrames) + It("starts with stream 1", func() { var str Stream done := make(chan struct{}) go func() { @@ -281,7 +281,7 @@ var _ = Describe("Streams Map", func() { Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) - It("starts with stream 3, if the crypto stream is stream 1", func() { + It("returns an implicitly opened stream, if a stream number is skipped", func() { var str Stream done := make(chan struct{}) go func() { @@ -294,23 +294,7 @@ var _ = Describe("Streams Map", func() { _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) - }) - - It("returns an implicitly opened stream, if a stream number is skipped", func() { - var str Stream - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - var err error - str, err = m.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - _, err := m.GetOrOpenStream(5) - Expect(err).ToNot(HaveOccurred()) - Eventually(done).Should(BeClosed()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) It("returns to multiple accepts", func() { @@ -331,12 +315,12 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) close(done2) }() - _, err := m.GetOrOpenStream(5) // opens stream 3 and 5 + _, err := m.GetOrOpenStream(3) // opens stream 1 and 3 Expect(err).ToNot(HaveOccurred()) Eventually(done1).Should(BeClosed()) Eventually(done2).Should(BeClosed()) Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) - Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(3 + 5)) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) }) It("waits until a new stream is available", func() { @@ -350,10 +334,10 @@ var _ = Describe("Streams Map", func() { close(done) }() Consistently(done).ShouldNot(BeClosed()) - _, err := m.GetOrOpenStream(3) + _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) }) It("returns multiple streams on subsequent Accept calls, if available", func() { @@ -366,39 +350,46 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) close(done) }() - _, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) str, err = m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) It("blocks after accepting a stream", func() { - var accepted bool - _, err := m.GetOrOpenStream(3) + _, err := m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) str, err := m.AcceptStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + done := make(chan struct{}) go func() { defer GinkgoRecover() _, _ = m.AcceptStream() - accepted = true + close(done) }() - Consistently(func() bool { return accepted }).Should(BeFalse()) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + str.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any()) + m.CloseWithError(errors.New("shut down")) + Eventually(done).Should(BeClosed()) }) It("stops waiting when an error is registered", func() { testErr := errors.New("testErr") - var acceptErr error + done := make(chan struct{}) go func() { - _, acceptErr = m.AcceptStream() + defer GinkgoRecover() + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + close(done) }() - Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred()) + Consistently(done).ShouldNot(BeClosed()) m.CloseWithError(testErr) - Eventually(func() error { return acceptErr }).Should(MatchError(testErr)) + Eventually(done).Should(BeClosed()) }) It("immediately returns when Accept is called after an error was registered", func() { @@ -412,7 +403,7 @@ var _ = Describe("Streams Map", func() { Context("as a client", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames) + setNewStreamsMap(protocol.PerspectiveClient) m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) }) @@ -451,18 +442,18 @@ var _ = Describe("Streams Map", func() { It("doesn't reopen an already closed stream", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) - deleteStream(3) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + deleteStream(1) Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(3) + str, err = m.GetOrOpenStream(1) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) }) Context("client-side streams", func() { - It("starts with stream 1, if the crypto stream is stream 0", func() { - setNewStreamsMap(protocol.PerspectiveClient, versionIETFFrames) + It("starts with stream 1", func() { + setNewStreamsMap(protocol.PerspectiveClient) m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -472,15 +463,6 @@ var _ = Describe("Streams Map", func() { Expect(m.numIncomingStreams).To(BeZero()) }) - It("starts with stream 3, if the crypto stream is stream 1", func() { - s, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - Expect(s.StreamID()).To(BeEquivalentTo(3)) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeZero()) - }) - It("opens multiple streams", func() { s1, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -522,17 +504,17 @@ var _ = Describe("Streams Map", func() { Context("deleting streams", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) + setNewStreamsMap(protocol.PerspectiveServer) }) It("deletes an incoming stream", func() { - _, err := m.GetOrOpenStream(5) // open stream 3 and 5 + _, err := m.GetOrOpenStream(3) // open stream 1 and 3 Expect(err).ToNot(HaveOccurred()) Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) - err = m.DeleteStream(3) + err = m.DeleteStream(1) Expect(err).ToNot(HaveOccurred()) Expect(m.streams).To(HaveLen(1)) - Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.streams).To(HaveKey(protocol.StreamID(3))) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) }) @@ -555,15 +537,15 @@ var _ = Describe("Streams Map", func() { }) It("sets the flow control limit", func() { - setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) - _, err := m.GetOrOpenStream(5) + setNewStreamsMap(protocol.PerspectiveServer) + _, err := m.GetOrOpenStream(3) Expect(err).ToNot(HaveOccurred()) - m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: 3, + m.streams[1].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 1, ByteOffset: 321, }) - m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: 5, + m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 3, ByteOffset: 321, }) m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321})