refactor stream ID parity check in the streams map

No functional change expected.
This commit is contained in:
Marten Seemann
2017-12-05 14:36:58 +07:00
parent ccd91a36b7
commit b286aae7a4
2 changed files with 18 additions and 24 deletions

View File

@@ -73,6 +73,14 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
return &sm
}
// getStreamPerspective says which side should initiate a stream
func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
if id%2 == 0 {
return protocol.PerspectiveServer
}
return protocol.PerspectiveClient
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
@@ -92,28 +100,15 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
return s, nil
}
if m.perspective == protocol.PerspectiveServer {
if id%2 == 0 {
if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already
if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already
if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
}
if m.perspective == protocol.PerspectiveClient {
if id%2 == 1 {
if id <= m.nextStream { // this is a client-side stream that we already opened.
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
}
// sid is the next stream that will be opened
sid := m.highestStreamOpenedByPeer + 2
@@ -123,8 +118,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
}
for ; sid <= id; sid += 2 {
_, err := m.openRemoteStream(sid)
if err != nil {
if _, err := m.openRemoteStream(sid); err != nil {
return nil, err
}
}

View File

@@ -64,14 +64,14 @@ var _ = Describe("Streams Map", func() {
It("rejects streams with even IDs", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side"))
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 6"))
})
It("rejects streams with even IDs, which are lower thatn the highest client-side stream", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())
_, err = m.GetOrOpenStream(4)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4 from client-side"))
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 4"))
})
It("gets existing streams", func() {
@@ -421,14 +421,14 @@ var _ = Describe("Streams Map", func() {
Context("client-side streams", func() {
It("rejects streams with odd IDs", func() {
_, err := m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
})
It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() {
_, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred())
_, err = m.GetOrOpenStream(5)
Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side"))
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
})
It("gets new streams", func() {