diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 91008bdd6..88f1c4467 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -29,81 +29,45 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame if typeByte&0xf8 == 0x10 { frame, err = parseStreamFrame(r, v) if err != nil { - err = qerr.Error(qerr.InvalidStreamData, err.Error()) + return nil, qerr.Error(qerr.InvalidFrameData, err.Error()) } - return frame, err + return frame, nil } // TODO: implement all IETF QUIC frame types switch typeByte { case 0x1: frame, err = parseResetStreamFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) - } case 0x2, 0x3: frame, err = parseConnectionCloseFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) - } case 0x4: frame, err = parseMaxDataFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } case 0x5: frame, err = parseMaxStreamDataFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } case 0x7: frame, err = parsePingFrame(r, v) case 0x8: frame, err = parseDataBlockedFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } case 0x9: frame, err = parseStreamDataBlockedFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } case 0xa, 0xb: frame, err = parseStreamsBlockedFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0xc: frame, err = parseStopSendingFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0xe: frame, err = parsePathChallengeFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0xf: frame, err = parsePathResponseFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0x1a, 0x1b: frame, err = parseAckFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidAckData, err.Error()) - } case 0x1c, 0x1d: frame, err = parseMaxStreamsFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0x18: frame, err = parseCryptoFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } default: - err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + err = fmt.Errorf("unknown type byte 0x%x", typeByte) } - return frame, err + if err != nil { + return nil, qerr.Error(qerr.InvalidFrameData, err.Error()) + } + return frame, nil } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index f3cec455e..be884c470 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -208,25 +208,14 @@ var _ = Describe("Frame parsing", func() { }) It("errors on invalid frames", func() { - for b, e := range map[byte]qerr.ErrorCode{ - 0x01: qerr.InvalidRstStreamData, - 0x02: qerr.InvalidConnectionCloseData, - 0x04: qerr.InvalidWindowUpdateData, - 0x05: qerr.InvalidWindowUpdateData, - 0x06: qerr.InvalidFrameData, - 0x08: qerr.InvalidBlockedData, - 0x09: qerr.InvalidBlockedData, - 0x0a: qerr.InvalidFrameData, - 0x0c: qerr.InvalidFrameData, - 0x0e: qerr.InvalidFrameData, - 0x0f: qerr.InvalidFrameData, - 0x10: qerr.InvalidStreamData, - 0x1a: qerr.InvalidAckData, - 0x1b: qerr.InvalidAckData, - } { - _, err := ParseNextFrame(bytes.NewReader([]byte{b}), versionIETFFrames) - Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(e)) + f := &MaxStreamDataFrame{ + StreamID: 0x1337, + ByteOffset: 0xdeadbeef, } + b := &bytes.Buffer{} + f.Write(b, versionIETFFrames) + _, err := ParseNextFrame(bytes.NewReader(b.Bytes()[:b.Len()-2]), versionIETFFrames) + Expect(err).To(HaveOccurred()) + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidFrameData)) }) })