From 2fbc994d29a852111043fe3d3804d5a0a64be87f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 25 Mar 2018 18:33:46 +0200 Subject: [PATCH] move the unencrypted STREAM frame check from the unpacker to the session --- mint_utils.go | 4 ++++ mint_utils_test.go | 2 +- packet_unpacker.go | 6 ------ packet_unpacker_test.go | 47 ----------------------------------------- session.go | 6 ++++-- session_test.go | 33 +++++++++++++++++++++++++---- 6 files changed, 38 insertions(+), 60 deletions(-) diff --git a/mint_utils.go b/mint_utils.go index 5156a652..b65de21e 100644 --- a/mint_utils.go +++ b/mint_utils.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "io" "github.com/bifurcation/mint" @@ -123,6 +124,9 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio if frame == nil { return nil, errors.New("Packet doesn't contain a STREAM_FRAME") } + if frame.StreamID != version.CryptoStreamID() { + return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID) + } // We don't need a check for the stream ID here. // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. if frame.Offset != 0 { diff --git a/mint_utils_test.go b/mint_utils_test.go index d1afd804..a0905ca9 100644 --- a/mint_utils_test.go +++ b/mint_utils_test.go @@ -125,7 +125,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() { } p := packPacket([]wire.Frame{f}) _, err := unpackInitialPacket(aead, hdr, p, ver) - Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42")) + Expect(err).To(MatchError("Received STREAM_FRAME for wrong stream (Stream ID 42)")) }) It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index 0a98dc06..730592f3 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -49,11 +48,6 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by if frame == nil { break } - if sf, ok := frame.(*wire.StreamFrame); ok { - if sf.StreamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { - return nil, qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", sf.StreamID)) - } - } fs = append(fs, frame) } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 50548d7a..d407811f 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -71,51 +71,4 @@ var _ = Describe("Packet unpacker", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionSecure)) }) - - Context("unpacking STREAM frames", func() { - BeforeEach(func() { - unpacker.version = versionGQUICFrames - }) - - It("unpacks unencrypted STREAM frames on the crypto stream", func() { - unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted - f := &wire.StreamFrame{ - StreamID: versionGQUICFrames.CryptoStreamID(), - Data: []byte("foobar"), - } - err := f.Write(buf, versionGQUICFrames) - Expect(err).ToNot(HaveOccurred()) - setData(buf.Bytes()) - packet, err := unpacker.Unpack(hdrBin, hdr, data) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.frames).To(Equal([]wire.Frame{f})) - }) - - It("unpacks encrypted STREAM frames on the crypto stream", func() { - unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure - f := &wire.StreamFrame{ - StreamID: versionGQUICFrames.CryptoStreamID(), - Data: []byte("foobar"), - } - err := f.Write(buf, versionGQUICFrames) - Expect(err).ToNot(HaveOccurred()) - setData(buf.Bytes()) - packet, err := unpacker.Unpack(hdrBin, hdr, data) - Expect(err).ToNot(HaveOccurred()) - Expect(packet.frames).To(Equal([]wire.Frame{f})) - }) - - It("does not unpack unencrypted STREAM frames on higher streams", func() { - unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted - f := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("foobar"), - } - err := f.Write(buf, versionGQUICFrames) - Expect(err).ToNot(HaveOccurred()) - setData(buf.Bytes()) - _, err = unpacker.Unpack(hdrBin, hdr, data) - Expect(err).To(MatchError(qerr.Error(qerr.UnencryptedStreamData, "received unencrypted stream data on stream 3"))) - }) - }) }) diff --git a/session.go b/session.go index 622b668e..a5cccbed 100644 --- a/session.go +++ b/session.go @@ -572,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve wire.LogFrame(ff, false) switch frame := ff.(type) { case *wire.StreamFrame: - err = s.handleStreamFrame(frame) + err = s.handleStreamFrame(frame, encLevel) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel) case *wire.ConnectionCloseFrame: @@ -615,12 +615,14 @@ func (s *session) handlePacket(p *receivedPacket) { } } -func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { +func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error { if frame.StreamID == s.version.CryptoStreamID() { if frame.FinBit { return errors.New("Received STREAM frame with FIN bit for the crypto stream") } return s.cryptoStream.handleStreamFrame(frame) + } else if encLevel <= protocol.EncryptionUnencrypted { + return qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", frame.StreamID)) } str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { diff --git a/session_test.go b/session_test.go index 3b57a6b5..796d13e2 100644 --- a/session_test.go +++ b/session_test.go @@ -203,7 +203,7 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) str.EXPECT().handleStreamFrame(f) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - err := sess.handleStreamFrame(f) + err := sess.handleStreamFrame(f, protocol.EncryptionForwardSecure) Expect(err).ToNot(HaveOccurred()) }) @@ -216,7 +216,7 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) str.EXPECT().handleStreamFrame(f).Return(testErr) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - err := sess.handleStreamFrame(f) + err := sess.handleStreamFrame(f, protocol.EncryptionForwardSecure) Expect(err).To(MatchError(testErr)) }) @@ -225,7 +225,7 @@ var _ = Describe("Session", func() { err := sess.handleStreamFrame(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), - }) + }, protocol.EncryptionForwardSecure) Expect(err).ToNot(HaveOccurred()) }) @@ -234,9 +234,34 @@ var _ = Describe("Session", func() { StreamID: sess.version.CryptoStreamID(), Offset: 0x1337, FinBit: true, - }) + }, protocol.EncryptionForwardSecure) Expect(err).To(MatchError("Received STREAM frame with FIN bit for the crypto stream")) }) + + It("accepts unencrypted STREAM frames on the crypto stream", func() { + f := &wire.StreamFrame{ + StreamID: versionGQUICFrames.CryptoStreamID(), + Data: []byte("foobar"), + } + err := sess.handleStreamFrame(f, protocol.EncryptionUnencrypted) + Expect(err).ToNot(HaveOccurred()) + }) + + It("unpacks encrypted STREAM frames on the crypto stream", func() { + err := sess.handleStreamFrame(&wire.StreamFrame{ + StreamID: versionGQUICFrames.CryptoStreamID(), + Data: []byte("foobar"), + }, protocol.EncryptionSecure) + Expect(err).ToNot(HaveOccurred()) + }) + + It("does not unpack unencrypted STREAM frames on higher streams", func() { + err := sess.handleStreamFrame(&wire.StreamFrame{ + StreamID: 3, + Data: []byte("foobar"), + }, protocol.EncryptionUnencrypted) + Expect(err).To(MatchError(qerr.Error(qerr.UnencryptedStreamData, "received unencrypted stream data on stream 3"))) + }) }) Context("handling ACK frames", func() {