mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
This commit is contained in:
committed by
GitHub
parent
44f481a45c
commit
9e049a32a1
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
@@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 07: Stress Test - Multiple Kernels with Validation
|
||||
|
||||
Consolidated stress test that:
|
||||
1. Declares multiple kernel configurations (various tiles, pipelines, layouts)
|
||||
2. Prints all registered kernels with details
|
||||
3. Validates each kernel against NumPy reference
|
||||
4. Optional benchmarking mode
|
||||
|
||||
This tests:
|
||||
- Multiple tile sizes (64x64, 128x128, 256x256)
|
||||
- Multiple pipelines (compv3, compv4)
|
||||
- Multiple data types (fp16, bf16)
|
||||
- Different schedulers (intrawave, interwave)
|
||||
|
||||
Complexity: ★★★★☆
|
||||
|
||||
Usage:
|
||||
python3 07_stress_test.py
|
||||
python3 07_stress_test.py --help
|
||||
python3 07_stress_test.py --num-kernels 10
|
||||
python3 07_stress_test.py --benchmark
|
||||
python3 07_stress_test.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
Validator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """One kernel configuration to exercise in the stress test."""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    layout: str = "rcr"

    def to_config(self, dtype: str, arch: str) -> KernelConfig:
        """Build the KernelConfig for this spec at the given dtype/arch."""
        # The per-wave warp tile cannot exceed the wave's share of the block
        # tile, so clamp warp_m/warp_n for the smaller tile configurations.
        axis_name = {"r": "row", "c": "col"}
        clamped_warp_m = min(self.warp_m, self.tile_m // self.wave_m)
        clamped_warp_n = min(self.warp_n, self.tile_n // self.wave_n)

        return KernelConfig(
            dtype_a=dtype,
            dtype_b=dtype,
            dtype_c=dtype,
            dtype_acc="fp32",
            layout_a=axis_name[self.layout[0]],
            layout_b=axis_name[self.layout[1]],
            layout_c=axis_name[self.layout[2]],
            tile_m=self.tile_m,
            tile_n=self.tile_n,
            tile_k=self.tile_k,
            wave_m=self.wave_m,
            wave_n=self.wave_n,
            wave_k=self.wave_k,
            warp_m=clamped_warp_m,
            warp_n=clamped_warp_n,
            warp_k=self.warp_k,
            pipeline=self.pipeline,
            scheduler=self.scheduler,
            epilogue="cshuffle",
            gfx_arch=arch,
        )
|
||||
|
||||
|
||||
# Stress-test matrix: a spread of tile sizes, pipelines, and schedulers.
# Two warp-tile presets cover the small (16x16x32) and regular (32x32x16)
# per-wave configurations used by the specs below.
_WARP_SMALL = dict(warp_m=16, warp_n=16, warp_k=32)
_WARP_REGULAR = dict(warp_m=32, warp_n=32, warp_k=16)

KERNEL_SPECS = [
    # Small tiles - compv3 / compv4
    KernelSpec("small_compv3", 64, 64, 32, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_SMALL),
    KernelSpec("small_compv4", 64, 64, 32, wave_m=2, wave_n=2, pipeline="compv4", **_WARP_SMALL),
    # Medium tiles
    KernelSpec("medium_compv3", 128, 128, 32, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_REGULAR),
    KernelSpec("medium_compv4", 128, 128, 32, wave_m=2, wave_n=2, pipeline="compv4", **_WARP_REGULAR),
    KernelSpec("medium_k64", 128, 128, 64, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_REGULAR),
    # Rectangular tiles
    KernelSpec("rect_64x128", 64, 128, 32, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_REGULAR),
    KernelSpec("rect_128x64", 128, 64, 32, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_REGULAR),
    # Different schedulers
    KernelSpec("interwave", 128, 128, 32, wave_m=2, wave_n=2, pipeline="compv3", scheduler="interwave", **_WARP_REGULAR),
    # Large tiles
    KernelSpec("large_compv3", 256, 128, 32, wave_m=2, wave_n=2, pipeline="compv3", **_WARP_REGULAR),
    KernelSpec("large_compv4", 256, 128, 64, wave_m=2, wave_n=2, pipeline="compv4", **_WARP_REGULAR),
]
|
||||
|
||||
|
||||
def print_kernel_summary(specs: List[KernelSpec], dtype: str):
    """Render a table listing every kernel spec that will be tested."""
    bar = "=" * 80
    rule = " " + "-" * 78

    print("\n" + bar)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(bar)
    print(
        f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}"
    )
    print(rule)

    # One row per spec, 1-indexed to match the validation table later on.
    for idx, ks in enumerate(specs, 1):
        dims = f"{ks.tile_m}x{ks.tile_n}x{ks.tile_k}"
        waves = f"{ks.wave_m}x{ks.wave_n}x{ks.wave_k}"
        warps = f"{ks.warp_m}x{ks.warp_n}x{ks.warp_k}"
        print(
            f" {idx:<3} {ks.name:<18} {dims:<12} {waves:<10} {warps:<12} {ks.pipeline:<10} {ks.scheduler:<10}"
        )

    print(rule)
    print(f" Data type: {dtype}\n")
|
||||
|
||||
|
||||
def validate_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    validator: Validator,
    kernel_index: int = 0,
    verbose: bool = False,
) -> Tuple[bool, float, str]:
    """
    Validate a single kernel configuration.
    Returns: (passed, max_error, message)
    """
    # fp16 and bf16 both use float16 host buffers here; everything else fp32.
    host_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    # Build the kernel and register it under a spec-unique name.
    setup = setup_gemm_dispatcher(
        config=spec.to_config(dtype, arch),
        registry_name=f"stress_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, f"Setup failed: {setup.error}"

    dispatcher = setup.dispatcher
    M = N = K = size

    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, f"Size {M}x{N}x{K} not supported"

    # Seed per kernel so each spec is exercised with distinct matrices.
    np.random.seed(42 + kernel_index * 1000)
    A = (np.random.randn(M, K) * 0.1).astype(host_dtype)
    B = (np.random.randn(K, N) * 0.1).astype(host_dtype)

    # Execute the GEMM on the GPU.
    result = dispatcher.run(A, B, M, N, K)
    if not result.success:
        cleanup_gemm()
        return False, 0.0, "GPU execution failed"

    # Compare against a float32 NumPy reference rounded to the host dtype.
    reference = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(host_dtype)
    ok, max_err, _ = validator.check(result.output, reference)

    cleanup_gemm()

    return ok, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS"
|
||||
|
||||
|
||||
def benchmark_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    warmup: int = 3,
    iterations: int = 10,
) -> Tuple[bool, float, float]:
    """
    Benchmark a kernel configuration.
    Returns: (success, avg_time_ms, tflops)
    """
    # fp16 and bf16 both use float16 host buffers here; everything else fp32.
    host_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    setup = setup_gemm_dispatcher(
        config=spec.to_config(dtype, arch),
        registry_name=f"bench_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, 0.0

    dispatcher = setup.dispatcher
    M = N = K = size

    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, 0.0

    A = (np.random.randn(M, K) * 0.1).astype(host_dtype)
    B = (np.random.randn(K, N) * 0.1).astype(host_dtype)

    # Untimed warmup runs.
    for _ in range(warmup):
        dispatcher.run(A, B, M, N, K)

    # Timed runs; keep only the iterations that actually succeeded.
    samples = [
        r.time_ms
        for r in (dispatcher.run(A, B, M, N, K) for _ in range(iterations))
        if r.success
    ]

    cleanup_gemm()

    if not samples:
        return False, 0.0, 0.0

    avg_time = sum(samples) / len(samples)
    tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12

    return True, avg_time, tflops
|
||||
|
||||
|
||||
def main() -> int:
    """Run the stress test over KERNEL_SPECS; return 0 if no kernel failed."""
    parser = argparse.ArgumentParser(
        description="GEMM Stress Test - Multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 07_stress_test.py                 # Test all kernels
  python3 07_stress_test.py --num-kernels 5 # Test first 5 kernels
  python3 07_stress_test.py --benchmark     # Include benchmarks
  python3 07_stress_test.py --dtype bf16    # Test BF16
  python3 07_stress_test.py --size 2048     # Use 2048x2048 matrices
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Include benchmark timing",
    )
    parser.add_argument(
        "--rtol",
        type=float,
        default=1e-2,
        help="Relative tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-2,
        help="Absolute tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()

    # Clear any dispatcher state left behind by a previous example run.
    reset_for_example()

    print("=" * 80)
    print("Example 07: GEMM Stress Test - Multiple Kernels")
    print("=" * 80)

    # Select kernels to test (--num-kernels 0 means all of them).
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # Print kernel summary
    print_kernel_summary(specs, args.dtype)

    # Run validation
    print("\n" + "=" * 80)
    print(" VALIDATION RESULTS")
    print("=" * 80)

    validator = Validator(rtol=args.rtol, atol=args.atol)

    # Table header depends on whether timing columns will be shown.
    if args.benchmark:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}"
        )
    else:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}"
        )
    print(" " + "-" * 78)

    passed = 0
    failed = 0
    skipped = 0

    for i, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"

        try:
            # kernel_index=i seeds the RNG so each kernel sees unique data.
            is_valid, max_err, info = validate_kernel(
                spec, args.dtype, args.arch, args.size, validator, kernel_index=i
            )

            if is_valid:
                status = "PASS"
                passed += 1
            else:
                status = "FAIL"
                failed += 1

            if args.benchmark:
                success, avg_time, tflops = benchmark_kernel(
                    spec, args.dtype, args.arch, args.size
                )
                if success:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}"
                    )
                else:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}"
                    )
            else:
                print(
                    f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}"
                )

        except Exception as e:
            # Unexpected errors (e.g. build failures) count as skips so the
            # remaining kernels still get tested.
            skipped += 1
            print(
                f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}"
            )

    # Summary
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)
    total = passed + failed + skipped
    print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped")
    print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
    print(f" Architecture: {args.arch}")

    if failed == 0 and skipped == 0:
        print("\n *** ALL KERNELS PASSED ***")
    elif failed > 0:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print("=" * 80)

    # NOTE(review): skipped kernels do not affect the exit code, only
    # failures do — presumably intentional so flaky builds don't fail CI.
    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user