diff --git a/internal/wire/header.go b/internal/wire/header.go index 1274e83da..678a04a24 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -40,37 +39,27 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti // https://datatracker.ietf.org/doc/html/rfc8999#section-5.1. // This function should only be called on Long Header packets for which we don't support the version. func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) { - r := bytes.NewReader(data) - remaining := r.Len() - src, dest, err := parseArbitraryLenConnectionIDs(r) - return remaining - r.Len(), src, dest, err -} - -func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) { - r.Seek(5, io.SeekStart) // skip first byte and version field - destConnIDLen, err := r.ReadByte() - if err != nil { - return nil, nil, err + startLen := len(data) + if len(data) < 6 { + return 0, nil, nil, io.EOF } + data = data[5:] // skip first byte and version field + destConnIDLen := data[0] + data = data[1:] destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen) - if _, err := io.ReadFull(r, destConnID); err != nil { - if err == io.ErrUnexpectedEOF { - err = io.EOF - } - return nil, nil, err + if len(data) < int(destConnIDLen)+1 { + return 0, nil, nil, io.EOF } - srcConnIDLen, err := r.ReadByte() - if err != nil { - return nil, nil, err + copy(destConnID, data) + data = data[destConnIDLen:] + srcConnIDLen := data[0] + data = data[1:] + if len(data) < int(srcConnIDLen) { + return 0, nil, nil, io.EOF } srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen) - if _, err := io.ReadFull(r, srcConnID); err != nil { - if err == io.ErrUnexpectedEOF { - err = io.EOF - } - return nil, nil, err - } - return destConnID, srcConnID, nil + copy(srcConnID, data) + return startLen - len(data) + int(srcConnIDLen), destConnID, srcConnID, nil } func IsPotentialQUICPacket(firstByte byte) bool { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 64fdba560..aad1eef00 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -575,3 +575,39 @@ func BenchmarkParseRetry(b *testing.B) { } } } + +func BenchmarkArbitraryHeaderParsing(b *testing.B) { + b.Run("dest 8/ src 10", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 8, 10) }) + b.Run("dest 20 / src 20", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 20, 20) }) + b.Run("dest 100 / src 150", func(b *testing.B) { benchmarkArbitraryHeaderParsing(b, 100, 150) }) +} + +func benchmarkArbitraryHeaderParsing(b *testing.B, destLen, srcLen int) { + destConnID := make([]byte, destLen) + rand.Read(destConnID) + srcConnID := make([]byte, srcLen) + rand.Read(srcConnID) + buf := []byte{0x80, 1, 2, 3, 4} + buf = append(buf, uint8(destLen)) + buf = append(buf, destConnID...) + buf = append(buf, uint8(srcLen)) + buf = append(buf, srcConnID...) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + parsed, d, s, err := ParseArbitraryLenConnectionIDs(buf) + if err != nil { + b.Fatal(err) + } + if parsed != len(buf) { + b.Fatal("expected to parse entire slice") + } + if !bytes.Equal(destConnID, d.Bytes()) { + b.Fatalf("destination connection IDs don't match: %v vs %v", destConnID, d.Bytes()) + } + if !bytes.Equal(srcConnID, s.Bytes()) { + b.Fatalf("source connection IDs don't match: %v vs %v", srcConnID, s.Bytes()) + } + } +} diff --git a/quicvarint/varint_test.go b/quicvarint/varint_test.go index 104b0e620..6665f2ecb 100644 --- a/quicvarint/varint_test.go +++ b/quicvarint/varint_test.go @@ -313,3 +313,24 @@ func benchmarkParse(b *testing.B, inputs []benchmarkValue) { } } } + +func BenchmarkAppend(b *testing.B) { + b.Run("1-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(min(b.N, 1024), maxVarInt1)) }) + b.Run("2-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(min(b.N, 1024), maxVarInt2)) }) + b.Run("4-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(min(b.N, 1024), maxVarInt4)) }) + b.Run("8-byte", func(b *testing.B) { benchmarkAppend(b, randomValues(min(b.N, 1024), maxVarInt8)) }) +} + +func benchmarkAppend(b *testing.B, inputs []benchmarkValue) { + buf := make([]byte, 8) + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf = buf[:0] + index := i % 1024 + buf = Append(buf, inputs[index].v) + + if !bytes.Equal(buf, inputs[index].b) { + b.Fatalf("expected to write %v, wrote %v", inputs[i].b, buf) + } + } +}