diff --git a/streams_map.go b/streams_map.go index 0517685df..e8308f69f 100644 --- a/streams_map.go +++ b/streams_map.go @@ -8,6 +8,7 @@ import ( "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" ) type streamsMap struct { @@ -23,6 +24,8 @@ type streamsMap struct { nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID + nextStreamCond sync.Cond + nextStreamToAccept protocol.StreamID newStream newStreamLambda @@ -47,11 +50,14 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect newStream: newStream, connectionParameters: connectionParameters, } + sm.nextStreamCond.L = &sm.mutex if pers == protocol.PerspectiveClient { sm.nextStream = 1 + sm.nextStreamToAccept = 2 } else { sm.nextStream = 2 + sm.nextStreamToAccept = 1 } return &sm @@ -94,6 +100,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { sid -= 2 } + m.nextStreamCond.Broadcast() return m.streams[id], nil } @@ -156,6 +163,25 @@ func (m *streamsMap) OpenStream() (*stream, error) { return s, nil } +// AcceptStream returns the next stream opened by the peer +// it blocks until a new stream is opened +// TODO: implement error conditions +func (m *streamsMap) AcceptStream() (utils.Stream, error) { + m.mutex.Lock() + var str utils.Stream + for { + var ok bool + str, ok = m.streams[m.nextStreamToAccept] + if ok { + break + } + m.nextStreamCond.Wait() + } + m.nextStreamToAccept += 2 + m.mutex.Unlock() + return str, nil +} + func (m *streamsMap) Iterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/streams_map_test.go b/streams_map_test.go index 9277e6b7b..b8a00241e 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -8,6 +8,7 @@ import ( "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -210,6 +211,114 @@ var _ = Describe("Streams Map", func() { }) }) }) + + Context("accepting streams", func() { + It("does nothing if no stream is opened", func() { + var accepted bool + go func() { + _, _ = m.AcceptStream() + accepted = true + }() + Consistently(func() bool { return accepted }).Should(BeFalse()) + }) + + It("accepts stream 1 first", func() { + var str utils.Stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + }) + + It("returns an implicitly opened stream, if a stream number is skipped", func() { + var str utils.Stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + }) + + It("returns to multiple accepts", func() { + var str1, str2 utils.Stream + go func() { + defer GinkgoRecover() + var err error + str1, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + go func() { + defer GinkgoRecover() + var err error + str2, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(3) // opens stream 1 and 3 + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str1 }).ShouldNot(BeNil()) + Eventually(func() utils.Stream { return str2 }).ShouldNot(BeNil()) + Expect(str1.StreamID()).ToNot(Equal(str2.StreamID())) + Expect(str1.StreamID() + str2.StreamID()).To(BeEquivalentTo(1 + 3)) + }) + + It("waits a new stream is available", func() { + var str utils.Stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + Consistently(func() utils.Stream { return str }).Should(BeNil()) + _, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + }) + + It("returns multiple streams on subsequent Accept calls, if available", func() { + var str utils.Stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) + }) + + It("blocks after accepting a stream", func() { + var accepted bool + _, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + str, err := m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + go func() { + defer GinkgoRecover() + _, _ = m.AcceptStream() + accepted = true + }() + Consistently(func() bool { return accepted }).Should(BeFalse()) + }) + }) }) Context("as a client", func() { @@ -258,6 +367,22 @@ var _ = Describe("Streams Map", func() { Expect(s2.StreamID()).To(Equal(s1.StreamID() + 2)) }) }) + + Context("accepting streams", func() { + It("accepts stream 2 first", func() { + var str utils.Stream + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(2) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) + }) + }) }) })