diff --git a/streams_map.go b/streams_map.go index acf68b9d5..907258cfd 100644 --- a/streams_map.go +++ b/streams_map.go @@ -20,7 +20,7 @@ type streamsMap struct { openStreams []protocol.StreamID roundRobinIndex int - nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID nextStreamOrErrCond sync.Cond openStreamOrErrCond sync.Cond @@ -58,18 +58,22 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex - nextOddStream := protocol.StreamID(1) - if ver.CryptoStreamID() == protocol.StreamID(1) { - nextOddStream = 3 + nextClientInitiatedStream := protocol.StreamID(1) + nextServerInitiatedStream := protocol.StreamID(2) + if !ver.UsesTLS() { + nextServerInitiatedStream = 2 + nextClientInitiatedStream = 3 + if pers == protocol.PerspectiveServer { + sm.highestStreamOpenedByPeer = 1 + } } - if pers == protocol.PerspectiveClient { - sm.nextStream = nextOddStream - sm.nextStreamToAccept = 2 + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream } else { - sm.nextStream = 2 - sm.nextStreamToAccept = nextOddStream + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream } - return &sm } @@ -81,6 +85,13 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv return protocol.PerspectiveClient } +func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID { + if m.perspective == protocol.PerspectiveServer && id == 0 { + return 1 + } + return id + 2 +} + // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { @@ -101,7 +112,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { } if m.perspective == m.streamInitiatedBy(id) { - if id <= m.nextStream { // this is a stream opened by us. Must have been closed already + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already return nil, nil } return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) @@ -110,14 +121,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { return nil, nil } - // sid is the next stream that will be opened - sid := m.highestStreamOpenedByPeer + 2 - // if there is no stream opened yet, and this is the server, stream 1 should be openend - if sid == 2 && m.perspective == protocol.PerspectiveServer { - sid = 1 - } - - for ; sid <= id; sid += 2 { + for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) { if _, err := m.openRemoteStream(sid); err != nil { return nil, err } @@ -146,15 +150,14 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { } func (m *streamsMap) openStreamImpl() (streamI, error) { - id := m.nextStream if m.numOutgoingStreams >= m.maxOutgoingStreams { return nil, qerr.TooManyOpenStreams } m.numOutgoingStreams++ - m.nextStream += 2 - s := m.newStream(id) + s := m.newStream(m.nextStreamToOpen) m.putStream(s) + m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen) return s, nil } diff --git a/streams_map_test.go b/streams_map_test.go index a5c04d0c4..3d7f56376 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -54,9 +54,11 @@ var _ = Describe("Streams Map", func() { Context("client-side streams", func() { It("gets new streams", func() { - s, err := m.GetOrOpenStream(1) + s, err := m.GetOrOpenStream(3) Expect(err).NotTo(HaveOccurred()) - Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(3))) + Expect(m.streams).To(HaveLen(1)) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numOutgoingStreams).To(BeZero()) }) @@ -93,11 +95,11 @@ var _ = Describe("Streams Map", func() { }) It("opens skipped streams", func() { - _, err := m.GetOrOpenStream(5) + _, err := m.GetOrOpenStream(7) Expect(err).NotTo(HaveOccurred()) - Expect(m.streams).To(HaveKey(protocol.StreamID(1))) Expect(m.streams).To(HaveKey(protocol.StreamID(3))) Expect(m.streams).To(HaveKey(protocol.StreamID(5))) + Expect(m.streams).To(HaveKey(protocol.StreamID(7))) }) It("doesn't reopen an already closed stream", func() { @@ -121,7 +123,7 @@ var _ = Describe("Streams Map", func() { }) It("errors when too many streams are opened implicitely", func() { - _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1)) + _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 3)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) @@ -423,7 +425,7 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) }) - It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() { + It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() { _, err := m.GetOrOpenStream(6) Expect(err).NotTo(HaveOccurred()) _, err = m.GetOrOpenStream(5) @@ -434,6 +436,7 @@ var _ = Describe("Streams Map", func() { s, err := m.GetOrOpenStream(2) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) + Expect(m.streams).To(HaveLen(1)) Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) })