diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index e9d3d6a5..da775f8c 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "fmt" + "reflect" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -29,7 +30,11 @@ func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLev } r.UnreadByte() - return p.parseFrame(r, typeByte, encLevel) + f, err := p.parseFrame(r, typeByte, encLevel) + if err != nil { + return nil, qerr.Error(qerr.FrameEncodingError, err.Error()) + } + return f, nil } return nil, nil } @@ -39,59 +44,72 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protoc var err error if typeByte&0xf8 == 0x8 { frame, err = parseStreamFrame(r, p.version) - if err != nil { - return nil, qerr.Error(qerr.FrameEncodingError, err.Error()) + } else { + switch typeByte { + case 0x1: + frame, err = parsePingFrame(r, p.version) + case 0x2, 0x3: + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) + case 0x4: + frame, err = parseResetStreamFrame(r, p.version) + case 0x5: + frame, err = parseStopSendingFrame(r, p.version) + case 0x6: + frame, err = parseCryptoFrame(r, p.version) + case 0x7: + frame, err = parseNewTokenFrame(r, p.version) + case 0x10: + frame, err = parseMaxDataFrame(r, p.version) + case 0x11: + frame, err = parseMaxStreamDataFrame(r, p.version) + case 0x12, 0x13: + frame, err = parseMaxStreamsFrame(r, p.version) + case 0x14: + frame, err = parseDataBlockedFrame(r, p.version) + case 0x15: + frame, err = parseStreamDataBlockedFrame(r, p.version) + case 0x16, 0x17: + frame, err = parseStreamsBlockedFrame(r, p.version) + case 0x18: + frame, err = parseNewConnectionIDFrame(r, p.version) + case 0x19: + frame, err = parseRetireConnectionIDFrame(r, p.version) + case 0x1a: + frame, err = parsePathChallengeFrame(r, p.version) + case 0x1b: + frame, err = parsePathResponseFrame(r, p.version) + case 0x1c, 0x1d: + frame, err = parseConnectionCloseFrame(r, p.version) + default: + err = fmt.Errorf("unknown type byte 0x%x", typeByte) } - return frame, nil - } - switch typeByte { - case 0x1: - frame, err = parsePingFrame(r, p.version) - case 0x2, 0x3: - ackDelayExponent := p.ackDelayExponent - if encLevel != protocol.Encryption1RTT { - ackDelayExponent = protocol.DefaultAckDelayExponent - } - frame, err = parseAckFrame(r, ackDelayExponent, p.version) - case 0x4: - frame, err = parseResetStreamFrame(r, p.version) - case 0x5: - frame, err = parseStopSendingFrame(r, p.version) - case 0x6: - frame, err = parseCryptoFrame(r, p.version) - case 0x7: - frame, err = parseNewTokenFrame(r, p.version) - case 0x10: - frame, err = parseMaxDataFrame(r, p.version) - case 0x11: - frame, err = parseMaxStreamDataFrame(r, p.version) - case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, p.version) - case 0x14: - frame, err = parseDataBlockedFrame(r, p.version) - case 0x15: - frame, err = parseStreamDataBlockedFrame(r, p.version) - case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, p.version) - case 0x18: - frame, err = parseNewConnectionIDFrame(r, p.version) - case 0x19: - frame, err = parseRetireConnectionIDFrame(r, p.version) - case 0x1a: - frame, err = parsePathChallengeFrame(r, p.version) - case 0x1b: - frame, err = parsePathResponseFrame(r, p.version) - case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, p.version) - default: - err = fmt.Errorf("unknown type byte 0x%x", typeByte) } if err != nil { - return nil, qerr.Error(qerr.FrameEncodingError, err.Error()) + return nil, err + } + if !p.isAllowedAtEncLevel(frame, encLevel) { + return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) } return frame, nil } +func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame: + return true + } + case protocol.Encryption1RTT: + return true + } + return false +} + func (p *frameParser) SetAckDelayExponent(exp uint8) { p.ackDelayExponent = exp } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 8c30c91e..87f086d6 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -287,4 +287,71 @@ var _ = Describe("Frame parsing", func() { Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.FrameEncodingError)) }) + + Context("encryption level check", func() { + frames := []Frame{ + &PingFrame{}, + &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 42}}}, + &ResetStreamFrame{}, + &StopSendingFrame{}, + &CryptoFrame{}, + &NewTokenFrame{}, + &StreamFrame{Data: []byte("foobar")}, + &MaxDataFrame{}, + &MaxStreamDataFrame{}, + &MaxStreamsFrame{}, + &DataBlockedFrame{}, + &StreamDataBlockedFrame{}, + &StreamsBlockedFrame{}, + &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + &RetireConnectionIDFrame{}, + &PathChallengeFrame{}, + &PathResponseFrame{}, + &ConnectionCloseFrame{}, + } + + var framesSerialized [][]byte + + BeforeEach(func() { + framesSerialized = nil + for _, frame := range frames { + buf := &bytes.Buffer{} + Expect(frame.Write(buf, versionIETFFrames)).To(Succeed()) + framesSerialized = append(framesSerialized, buf.Bytes()) + } + }) + + It("rejects all frames but ACK, CRYPTO and CONNECTION_CLOSE in Initial packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionInitial) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not allowed at encryption level Initial")) + } + } + }) + + It("rejects all frames but ACK, CRYPTO and CONNECTION_CLOSE in Handshake packets", func() { + for i, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) + switch frames[i].(type) { + case *AckFrame, *ConnectionCloseFrame, *CryptoFrame: + Expect(err).ToNot(HaveOccurred()) + default: + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not allowed at encryption level Handshake")) + } + } + }) + + It("accepts all frame types in 1-RTT packets", func() { + for _, b := range framesSerialized { + _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + } + }) + }) }) diff --git a/session.go b/session.go index b71e96c8..13f3eb7d 100644 --- a/session.go +++ b/session.go @@ -818,10 +818,6 @@ func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.E } func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error { - // TODO(#1261): implement strict rules for frames types in unencrypted packets - if encLevel < protocol.Encryption1RTT { - return qerr.Error(qerr.ProtocolViolation, fmt.Sprintf("received unencrypted stream data on stream %d", frame.StreamID)) - } str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err diff --git a/session_test.go b/session_test.go index 03f2ebdf..d89d6f4a 100644 --- a/session_test.go +++ b/session_test.go @@ -142,14 +142,6 @@ var _ = Describe("Session", func() { }, protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) }) - - It("does not accept STREAM frames in non-1RTT packets", func() { - err := sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: 3, - Data: []byte("foobar"), - }, protocol.EncryptionHandshake) - Expect(err).To(MatchError(qerr.Error(qerr.ProtocolViolation, "received unencrypted stream data on stream 3"))) - }) }) Context("handling ACK frames", func() { @@ -540,12 +532,12 @@ var _ = Describe("Session", func() { Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ packetNumber: 0x1337, - encryptionLevel: protocol.EncryptionHandshake, + encryptionLevel: protocol.Encryption1RTT, hdr: hdr, data: buf.Bytes(), }, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionHandshake, rcvTime, true) + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true) sess.receivedPacketHandler = rph packet := getPacket(hdr, nil) packet.rcvTime = rcvTime