From e957ad61842ea8ca168e64753ccb3e81ac93eefe Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 11 May 2016 18:47:06 +0700 Subject: [PATCH] check that all new Streams initiated by the client have an odd StreamID work towards #78 --- session.go | 13 +++++++++++-- session_test.go | 9 +++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index 9c883c86..13c13fa4 100644 --- a/session.go +++ b/session.go @@ -23,6 +23,7 @@ type receivedPacket struct { } var ( + errInvalidStreamID = errors.New("STREAM_FRAME with invalid StreamID received") errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errWindowUpdateOnInvalidStream = errors.New("WINDOW_UPDATE received for unknown stream") ) @@ -225,9 +226,10 @@ func (s *Session) HandlePacket(remoteAddr interface{}, publicHeader *PublicHeade // TODO: Ignore data for closed streams func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { - if frame.StreamID == 0 { - return errors.New("Session: 0 is not a valid Stream ID") + if !s.isValidStreamID(frame.StreamID) { + return errInvalidStreamID } + s.streamsMutex.RLock() str, streamExists := s.streams[frame.StreamID] s.streamsMutex.RUnlock() @@ -249,6 +251,13 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return nil } +func (s *Session) isValidStreamID(streamID protocol.StreamID) bool { + if streamID%2 != 1 { + return false + } + return true +} + func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { if frame.StreamID == 0 { // TODO: handle connection level WindowUpdateFrames diff --git a/session_test.go b/session_test.go index 27d9c0de..dcf5df02 100644 --- a/session_test.go +++ b/session_test.go @@ -113,6 +113,15 @@ var _ = Describe("Session", func() { Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) }) + It("rejects streams with even StreamIDs", func() { + err := session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 4, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + }) + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(errInvalidStreamID)) + }) + It("handles existing streams", func() { session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5,