fuzzing: add frame validation logic (#4206)

This commit is contained in:
Marten Seemann
2023-12-14 12:39:02 +05:30
committed by GitHub
parent 048940927c
commit 6ffb9054a2
2 changed files with 55 additions and 5 deletions

View File

@@ -56,22 +56,23 @@ func Fuzz(data []byte) int {
continue continue
} }
} }
validateFrame(f)
startLen := len(b) startLen := len(b)
parsedLen := initialLen - len(data) parsedLen := initialLen - len(data)
b, err = f.Append(b, version) b, err = f.Append(b, version)
if err != nil { if err != nil {
panic(fmt.Sprintf("Error writing frame %#v: %s", f, err)) panic(fmt.Sprintf("error writing frame %#v: %s", f, err))
} }
frameLen := protocol.ByteCount(len(b) - startLen) frameLen := protocol.ByteCount(len(b) - startLen)
if f.Length(version) != frameLen { if f.Length(version) != frameLen {
panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) panic(fmt.Sprintf("inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version)))
} }
if sf, ok := f.(*wire.StreamFrame); ok { if sf, ok := f.(*wire.StreamFrame); ok {
sf.PutBack() sf.PutBack()
} }
if frameLen > protocol.ByteCount(parsedLen) { if frameLen > protocol.ByteCount(parsedLen) {
panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen)) panic(fmt.Sprintf("serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen))
} }
} }
@@ -80,3 +81,52 @@ func Fuzz(data []byte) int {
} }
return 1 return 1
} }
func validateFrame(frame wire.Frame) {
switch f := frame.(type) {
case *wire.StreamFrame:
if protocol.ByteCount(len(f.Data)) != f.DataLen() {
panic("STREAM frame: inconsistent data length")
}
case *wire.AckFrame:
if f.DelayTime < 0 {
panic(fmt.Sprintf("invalid ACK delay_time: %s", f.DelayTime))
}
if f.LargestAcked() < f.LowestAcked() {
panic("ACK: largest acknowledged is smaller than lowest acknowledged")
}
for _, r := range f.AckRanges {
if r.Largest < 0 || r.Smallest < 0 {
panic("ACK range contains a negative packet number")
}
}
if !f.AcksPacket(f.LargestAcked()) {
panic("ACK frame claims that largest acknowledged is not acknowledged")
}
if !f.AcksPacket(f.LowestAcked()) {
panic("ACK frame claims that lowest acknowledged is not acknowledged")
}
_ = f.AcksPacket(100)
_ = f.AcksPacket((f.LargestAcked() + f.LowestAcked()) / 2)
case *wire.NewConnectionIDFrame:
if f.ConnectionID.Len() < 1 || f.ConnectionID.Len() > 20 {
panic(fmt.Sprintf("invalid NEW_CONNECTION_ID frame length: %s", f.ConnectionID))
}
case *wire.NewTokenFrame:
if len(f.Token) == 0 {
panic("NEW_TOKEN frame with an empty token")
}
case *wire.MaxStreamsFrame:
if f.MaxStreamNum > protocol.MaxStreamCount {
panic("MAX_STREAMS frame with an invalid Maximum Streams value")
}
case *wire.StreamsBlockedFrame:
if f.StreamLimit > protocol.MaxStreamCount {
panic("STREAMS_BLOCKED frame with an invalid Maximum Streams value")
}
case *wire.ConnectionCloseFrame:
if f.IsApplicationError && f.FrameType != 0 {
panic("CONNECTION_CLOSE for an application error containing a frame type")
}
}
}

View File

@@ -37,7 +37,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 { if delayTime < 0 {
// If the delay time overflows, set it to the maximum encodable value. // If the delay time overflows, set it to the maximum encode-able value.
delayTime = utils.InfDuration delayTime = utils.InfDuration
} }
frame.DelayTime = delayTime frame.DelayTime = delayTime
@@ -57,9 +57,9 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
return errors.New("invalid first ACK range") return errors.New("invalid first ACK range")
} }
smallest := largestAcked - ackBlock smallest := largestAcked - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
// read all the other ACK ranges // read all the other ACK ranges
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
for i := uint64(0); i < numBlocks; i++ { for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r) g, err := quicvarint.Read(r)
if err != nil { if err != nil {