Files
composable_kernel/dispatcher/examples/gemm/python/07_stress_test.py
Vidyasagar Ananthan 920acd2c12 [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, mirroring the
existing dispatcher GEMM infrastructure in both architecture and
usability. The dispatcher provides a unified kernel dispatch system
with C++ and Python frontends, and until now it supported only GEMM
operations. This PR lets framework integrators use the same declarative
kernel workflow for convolutions as for GEMM: declare kernels,
JIT-build a registry, select kernels from the registry at runtime, and
dispatch to the GPU (a minimal sketch follows). Future PRs will add
runtime kernel-selection heuristics that autotune kernel parameters
based on the (problem, hardware architecture) pair.
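
For orientation, here is a minimal sketch of that workflow on the existing
GEMM side, using the `ctypes_utils` helpers from the example file included
below; the new conv frontends follow the same shape. The `KernelConfig`
values and the `gfx950` target are illustrative only:

```python
import numpy as np
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, cleanup_gemm

# 1) Declare a kernel (values illustrative; see KERNEL_SPECS below for real ones).
config = KernelConfig(
    dtype_a="fp16", dtype_b="fp16", dtype_c="fp16", dtype_acc="fp32",
    layout_a="row", layout_b="col", layout_c="row",
    tile_m=128, tile_n=128, tile_k=32,
    wave_m=2, wave_n=2, wave_k=1,
    warp_m=32, warp_n=32, warp_k=16,
    pipeline="compv3", scheduler="intrawave",
    epilogue="cshuffle", gfx_arch="gfx950",
)

# 2) JIT-build a registry containing the kernel.
setup = setup_gemm_dispatcher(
    config=config, registry_name="demo", verbose=False, auto_rebuild=True
)

# 3) Select within the registry and dispatch to the GPU.
if setup.success:
    M = N = K = 512
    A = np.random.randn(M, K).astype(np.float16)
    B = np.random.randn(K, N).astype(np.float16)
    result = setup.dispatcher.run(A, B, M, N, K)
cleanup_gemm()
```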

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher:

- `generated_conv_backend.hpp` enables `dispatcher.run(in, wei, out, problem)`
  for all six conv variants (fwd/bwd-data/bwd-weight x 2D/3D).
- Runtime heuristic kernel selection, plus a `GroupedConvKernelKey` carrying
  the full set of `ConvConfigBase` fields.
- On the Python side, parallel JIT compilation via `registry.build(max_workers)`
  and heuristic `registry.select()` (see the sketch after the table below).
- 7 C++ and 6 Python examples covering all directions, with CPU reference
  validation.
- Shared infrastructure improvements (`BaseRegistry` CRTP, structured
  exceptions).

As a sanity check, JIT compile time for a single kernel remains unchanged,
while multiple kernels now benefit from parallel compilation:
| Kernels | 1 worker | 8 workers |
|--------:|---------:|----------:|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |
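
A hypothetical sketch of how these pieces compose on the Python side; only
`registry.build(max_workers)`, `registry.select()`, and
`dispatcher.run(in, wei, out, problem)` are named by this PR, and everything
else (object construction, the `problem` type, the parameter names) is assumed:

```python
def run_grouped_conv(registry, problem, in_tensor, wei, out):
    """Hypothetical driver; only build/select/run are named in this PR."""
    registry.build(max_workers=8)          # parallel JIT compile of declared kernels
    dispatcher = registry.select(problem)  # heuristic runtime kernel selection
    return dispatcher.run(in_tensor, wei, out, problem)  # dispatch to GPU
```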

## Test Plan

145 ephemeral unit tests have been added to cover basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, and 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes
in both the C++ and Python examples.

## Test Result

30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 07: Stress Test - Multiple Kernels with Validation

Consolidated stress test that:
1. Declares multiple kernel configurations (various tiles, pipelines, layouts)
2. Prints all registered kernels with details
3. Validates each kernel against a NumPy reference
4. Optionally benchmarks each kernel

This tests:
- Multiple tile sizes (64x64, 128x128, 256x128, and rectangular variants)
- Multiple pipelines (compv3, compv4)
- Multiple data types (fp16, bf16, fp32 via --dtype)
- Different schedulers (intrawave, interwave)

Usage:
python3 07_stress_test.py
python3 07_stress_test.py --help
python3 07_stress_test.py --num-kernels 10
python3 07_stress_test.py --benchmark
python3 07_stress_test.py --dtype bf16
"""

import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple
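# Make the dispatcher's shared Python helpers importable (must precede the
# ctypes_utils import below).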
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Validator,
detect_gpu_arch,
)


@dataclass
class KernelSpec:
"""A kernel specification for testing"""
name: str
tile_m: int
tile_n: int
tile_k: int
wave_m: int = 2
wave_n: int = 2
wave_k: int = 1
warp_m: int = 32
warp_n: int = 32
warp_k: int = 16
pipeline: str = "compv3"
scheduler: str = "intrawave"
layout: str = "rcr"
def to_config(self, dtype: str, arch: str) -> KernelConfig:
"""Convert to KernelConfig"""
# Adjust warp tiles for smaller tiles
warp_m = min(self.warp_m, self.tile_m // self.wave_m)
warp_n = min(self.warp_n, self.tile_n // self.wave_n)
warp_k = self.warp_k
return KernelConfig(
dtype_a=dtype,
dtype_b=dtype,
dtype_c=dtype,
dtype_acc="fp32",
layout_a={"r": "row", "c": "col"}[self.layout[0]],
layout_b={"r": "row", "c": "col"}[self.layout[1]],
layout_c={"r": "row", "c": "col"}[self.layout[2]],
tile_m=self.tile_m,
tile_n=self.tile_n,
tile_k=self.tile_k,
wave_m=self.wave_m,
wave_n=self.wave_n,
wave_k=self.wave_k,
warp_m=warp_m,
warp_n=warp_n,
warp_k=warp_k,
pipeline=self.pipeline,
scheduler=self.scheduler,
epilogue="cshuffle",
gfx_arch=arch,
)


# Define stress test kernel configurations
KERNEL_SPECS = [
# Small tiles - compv3
KernelSpec(
"small_compv3",
64,
64,
32,
wave_m=2,
wave_n=2,
warp_m=16,
warp_n=16,
warp_k=32,
pipeline="compv3",
),
KernelSpec(
"small_compv4",
64,
64,
32,
wave_m=2,
wave_n=2,
warp_m=16,
warp_n=16,
warp_k=32,
pipeline="compv4",
),
# Medium tiles
KernelSpec(
"medium_compv3",
128,
128,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
),
KernelSpec(
"medium_compv4",
128,
128,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv4",
),
KernelSpec(
"medium_k64",
128,
128,
64,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
),
# Rectangular tiles
KernelSpec(
"rect_64x128",
64,
128,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
),
KernelSpec(
"rect_128x64",
128,
64,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
),
# Different schedulers
KernelSpec(
"interwave",
128,
128,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
scheduler="interwave",
),
# Large tiles
KernelSpec(
"large_compv3",
256,
128,
32,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv3",
),
KernelSpec(
"large_compv4",
256,
128,
64,
wave_m=2,
wave_n=2,
warp_m=32,
warp_n=32,
warp_k=16,
pipeline="compv4",
),
]


def print_kernel_summary(specs: List[KernelSpec], dtype: str):
"""Print a summary table of all kernel specs"""
print("\n" + "=" * 80)
print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
print("=" * 80)
print(
f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}"
)
print(" " + "-" * 78)
for i, spec in enumerate(specs, 1):
tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
wave = f"{spec.wave_m}x{spec.wave_n}x{spec.wave_k}"
warp = f"{spec.warp_m}x{spec.warp_n}x{spec.warp_k}"
print(
f" {i:<3} {spec.name:<18} {tile:<12} {wave:<10} {warp:<12} {spec.pipeline:<10} {spec.scheduler:<10}"
)
print(" " + "-" * 78)
print(f" Data type: {dtype}\n")


def validate_kernel(
spec: KernelSpec,
dtype: str,
arch: str,
size: int,
validator: Validator,
kernel_index: int = 0,
verbose: bool = False,
) -> Tuple[bool, float, str]:
"""
Validate a single kernel configuration.
Returns: (passed, max_error, message)
"""
np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
# Create config
config = spec.to_config(dtype, arch)
# Setup dispatcher
setup = setup_gemm_dispatcher(
config=config,
registry_name=f"stress_{spec.name}",
verbose=False,
auto_rebuild=True,
)
if not setup.success:
return False, 0.0, f"Setup failed: {setup.error}"
dispatcher = setup.dispatcher
M, N, K = size, size, size
if not dispatcher.is_supported(M, N, K):
cleanup_gemm()
return False, 0.0, f"Size {M}x{N}x{K} not supported"
# Use different seed per kernel to get unique test data
# This ensures each kernel is tested with different matrices
np.random.seed(42 + kernel_index * 1000)
A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
# Run GPU GEMM
result = dispatcher.run(A, B, M, N, K)
if not result.success:
cleanup_gemm()
return False, 0.0, "GPU execution failed"
# Validate against NumPy
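    # (reference computed in fp32 to avoid low-precision accumulation error, then cast back)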
C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
is_valid, max_err, _ = validator.check(result.output, C_ref)
cleanup_gemm()
return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS"


def benchmark_kernel(
spec: KernelSpec,
dtype: str,
arch: str,
size: int,
warmup: int = 3,
iterations: int = 10,
) -> Tuple[bool, float, float]:
"""
Benchmark a kernel configuration.
Returns: (success, avg_time_ms, tflops)
"""
np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
config = spec.to_config(dtype, arch)
setup = setup_gemm_dispatcher(
config=config,
registry_name=f"bench_{spec.name}",
verbose=False,
auto_rebuild=True,
)
if not setup.success:
return False, 0.0, 0.0
dispatcher = setup.dispatcher
M, N, K = size, size, size
if not dispatcher.is_supported(M, N, K):
cleanup_gemm()
return False, 0.0, 0.0
A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
# Warmup
for _ in range(warmup):
dispatcher.run(A, B, M, N, K)
# Benchmark
times = []
for _ in range(iterations):
result = dispatcher.run(A, B, M, N, K)
if result.success:
times.append(result.time_ms)
cleanup_gemm()
if not times:
return False, 0.0, 0.0
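    # GEMM cost is 2*M*N*K FLOPs; convert ms -> s, then FLOP/s -> TFLOPS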
avg_time = sum(times) / len(times)
tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12
return True, avg_time, tflops


def main():
parser = argparse.ArgumentParser(
description="GEMM Stress Test - Multiple kernels with validation",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python3 07_stress_test.py # Test all kernels
python3 07_stress_test.py --num-kernels 5 # Test first 5 kernels
python3 07_stress_test.py --benchmark # Include benchmarks
python3 07_stress_test.py --dtype bf16 # Test BF16
python3 07_stress_test.py --size 2048 # Use 2048x2048 matrices
""",
)
parser.add_argument(
"--dtype",
default="fp16",
choices=["fp16", "bf16", "fp32"],
help="Data type (default: fp16)",
)
parser.add_argument(
"--num-kernels",
type=int,
default=0,
help="Number of kernels to test (0 = all)",
)
parser.add_argument(
"--size",
type=int,
default=512,
help="Problem size MxNxK (default: 512)",
)
parser.add_argument(
"--benchmark",
action="store_true",
help="Include benchmark timing",
)
parser.add_argument(
"--rtol",
type=float,
default=1e-2,
help="Relative tolerance (default: 1e-2)",
)
parser.add_argument(
"--atol",
type=float,
default=1e-2,
help="Absolute tolerance (default: 1e-2)",
)
parser.add_argument(
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()
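    # Reset shared example/dispatcher state before running (per ctypes_utils helper)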
reset_for_example()
print("=" * 80)
print("Example 07: GEMM Stress Test - Multiple Kernels")
print("=" * 80)
# Select kernels to test
specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
# Print kernel summary
print_kernel_summary(specs, args.dtype)
# Run validation
print("\n" + "=" * 80)
print(" VALIDATION RESULTS")
print("=" * 80)
validator = Validator(rtol=args.rtol, atol=args.atol)
if args.benchmark:
print(
f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}"
)
else:
print(
f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}"
)
print(" " + "-" * 78)
passed = 0
failed = 0
skipped = 0
for i, spec in enumerate(specs, 1):
tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
try:
is_valid, max_err, info = validate_kernel(
spec, args.dtype, args.arch, args.size, validator, kernel_index=i
)
if is_valid:
status = "PASS"
passed += 1
else:
status = "FAIL"
failed += 1
if args.benchmark:
success, avg_time, tflops = benchmark_kernel(
spec, args.dtype, args.arch, args.size
)
if success:
print(
f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}"
)
else:
print(
f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}"
)
else:
print(
f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}"
)
except Exception as e:
skipped += 1
print(
f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}"
)
# Summary
print("\n" + "=" * 80)
print(" SUMMARY")
print("=" * 80)
total = passed + failed + skipped
print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped")
print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}")
print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
print(f" Architecture: {args.arch}")
if failed == 0 and skipped == 0:
print("\n *** ALL KERNELS PASSED ***")
elif failed > 0:
print(f"\n *** {failed} KERNELS FAILED ***")
print("=" * 80)
return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())