mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
This commit is contained in:
committed by
GitHub
parent
44f481a45c
commit
9e049a32a1
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 03: Benchmark
|
||||
|
||||
Performance benchmarking with compute-optimized kernel configuration.
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 03_benchmark.py
|
||||
python3 03_benchmark.py --help
|
||||
python3 03_benchmark.py --size 4096
|
||||
python3 03_benchmark.py --dtype bf16 --iterations 20
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Benchmark Example - performance testing",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python3 03_benchmark.py # Default benchmark suite
|
||||
python3 03_benchmark.py --size 4096 # Single size benchmark
|
||||
python3 03_benchmark.py --dtype bf16 # BF16 benchmark
|
||||
python3 03_benchmark.py --iterations 20 # More iterations
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bf16",
|
||||
choices=["fp16", "bf16", "fp32"],
|
||||
help="Data type (default: bf16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Single problem size MxNxK (default: run all sizes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
reset_for_example()
|
||||
|
||||
print("=" * 60)
|
||||
print("Example 03: Benchmark")
|
||||
print("=" * 60)
|
||||
|
||||
# =========================================================================
|
||||
# Step 1: Setup dispatcher with compute-optimized config
|
||||
# =========================================================================
|
||||
print("\nStep 1: Setup Dispatcher")
|
||||
|
||||
config = KernelConfig(
|
||||
dtype_a=args.dtype,
|
||||
dtype_b=args.dtype,
|
||||
dtype_c=args.dtype,
|
||||
tile_m=128,
|
||||
tile_n=128,
|
||||
tile_k=32,
|
||||
pipeline="compv4",
|
||||
scheduler="intrawave",
|
||||
gfx_arch=args.arch,
|
||||
)
|
||||
|
||||
setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
|
||||
if not setup.success:
|
||||
print(f" ERROR: {setup.error}")
|
||||
return 1
|
||||
|
||||
dispatcher = setup.dispatcher
|
||||
|
||||
# =========================================================================
|
||||
# Step 2: Benchmark
|
||||
# =========================================================================
|
||||
print("\nStep 2: Benchmark")
|
||||
|
||||
if args.size > 0:
|
||||
sizes = [(args.size, args.size, args.size)]
|
||||
else:
|
||||
sizes = [
|
||||
(512, 512, 512),
|
||||
(1024, 1024, 1024),
|
||||
(2048, 2048, 2048),
|
||||
(4096, 4096, 4096),
|
||||
(1024, 2048, 512),
|
||||
(2048, 1024, 2048),
|
||||
]
|
||||
|
||||
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
|
||||
|
||||
print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n")
|
||||
|
||||
print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}")
|
||||
print(" " + "-" * 60)
|
||||
|
||||
all_tflops = []
|
||||
|
||||
for M, N, K in sizes:
|
||||
if not dispatcher.is_supported(M, N, K):
|
||||
continue
|
||||
|
||||
A = np.random.randn(M, K).astype(np_dtype) * 0.1
|
||||
B = np.random.randn(K, N).astype(np_dtype) * 0.1
|
||||
|
||||
# Warmup
|
||||
for _ in range(args.warmup):
|
||||
dispatcher.run(A, B, M, N, K)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(args.iterations):
|
||||
result = dispatcher.run(A, B, M, N, K)
|
||||
if result.success:
|
||||
times.append(result.time_ms)
|
||||
|
||||
if times:
|
||||
min_time = min(times)
|
||||
avg_time = sum(times) / len(times)
|
||||
tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12
|
||||
all_tflops.append(tflops)
|
||||
print(
|
||||
f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
cleanup_gemm()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
|
||||
if all_tflops:
|
||||
print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
|
||||
print(f" Peak: {max(all_tflops):.2f} TFLOPS")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user