diff --git a/session.go b/session.go index 467c992e..e282c69a 100644 --- a/session.go +++ b/session.go @@ -563,6 +563,9 @@ func (s *session) handlePacket(p *receivedPacket) { func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { if frame.StreamID == s.version.CryptoStreamID() { + if frame.FinBit { + return errors.New("Received STREAM frame with FIN bit for the crypto stream") + } return s.cryptoStream.AddStreamFrame(frame) } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) @@ -582,6 +585,10 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { } func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + s.cryptoStream.UpdateSendWindow(frame.ByteOffset) + return nil + } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -595,6 +602,9 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error } func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received RST_STREAM frame for the crypto stream") + } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err diff --git a/session_test.go b/session_test.go index 52c9ad9c..d2a6efe4 100644 --- a/session_test.go +++ b/session_test.go @@ -269,7 +269,7 @@ var _ = Describe("Session", func() { } }) - Context("when handling STREAM frames", func() { + Context("handling STREAM frames", func() { BeforeEach(func() { sess.streamsMap.UpdateMaxStreamLimit(100) }) @@ -330,6 +330,15 @@ var _ = Describe("Session", func() { }) Expect(err).ToNot(HaveOccurred()) }) + + It("errors on a STREAM frame that would close the crypto stream", func() { + err := sess.handleStreamFrame(&wire.StreamFrame{ + StreamID: sess.version.CryptoStreamID(), + Offset: 0x1337, + FinBit: true, + }) + Expect(err).To(MatchError("Received STREAM frame with FIN bit for the crypto stream")) + }) }) Context("handling RST_STREAM frames", func() { @@ -383,6 +392,14 @@ var _ = Describe("Session", func() { }}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) + + It("erros when a RST_STREAM frame would reset the crypto stream", func() { + err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ + StreamID: sess.version.CryptoStreamID(), + ErrorCode: 123, + }) + Expect(err).To(MatchError("Received RST_STREAM frame for the crypto stream")) + }) }) Context("handling MAX_DATA and MAX_STREAM_DATA frames", func() { @@ -393,6 +410,18 @@ var _ = Describe("Session", func() { sess.connFlowController = connFC }) + It("updates the flow control window of the crypto stream", func() { + fc := mocks.NewMockStreamFlowController(mockCtrl) + offset := protocol.ByteCount(0x4321) + fc.EXPECT().UpdateSendWindow(offset) + sess.cryptoStream.(*stream).flowController = fc + err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: sess.version.CryptoStreamID(), + ByteOffset: offset, + }) + Expect(err).ToNot(HaveOccurred()) + }) + It("updates the flow control window of a stream", func() { offset := protocol.ByteCount(0x1234) str, err := sess.GetOrOpenStream(5)