mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 00:27:38 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
332 lines
11 KiB
Python
332 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Example 01: Basic GEMM with Multiple Kernels
|
|
|
|
Demonstrates:
|
|
1. Declaring multiple kernel configurations
|
|
2. Printing all registered kernels
|
|
3. Running each kernel and validating output
|
|
4. Comparing performance across kernels
|
|
|
|
Complexity: ★★☆☆☆
|
|
|
|
Usage:
|
|
python3 01_basic_gemm.py
|
|
python3 01_basic_gemm.py --help
|
|
python3 01_basic_gemm.py --dtype bf16
|
|
python3 01_basic_gemm.py --size 2048
|
|
"""
|
|
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
|
import numpy as np
|
|
|
|
from ctypes_utils import (
|
|
KernelConfig,
|
|
setup_gemm_dispatcher,
|
|
cleanup_gemm,
|
|
reset_for_example,
|
|
)
|
|
|
|
|
|
@dataclass
class KernelSpec:
    """One named GEMM kernel variant exercised by this example.

    A spec carries only the values that differ between the kernels under
    test; create_kernel_config() supplies everything else.
    """

    name: str  # label shown in the tables and used in the registry name
    tile_m: int  # tile extent along M
    tile_n: int  # tile extent along N
    tile_k: int  # tile extent along K
    pipeline: str = "compv3"  # pipeline variant; this file uses "compv3" and "compv4"
    scheduler: str = "intrawave"  # scheduler variant; "intrawave" or "interwave" here
|
|
|
|
|
|
# Kernel configurations exercised by this example (48 entries), grouped by
# tile shape and then by pipeline.  The order matters: --num-kernels N takes
# the first N entries of this list.
KERNEL_SPECS = [
    # 64x64 tiles, compv3 pipeline
    KernelSpec("small_64x64_k32", 64, 64, 32, pipeline="compv3"),
    KernelSpec("small_64x64_k64", 64, 64, 64, pipeline="compv3"),
    # 64x64 tiles, compv4 pipeline
    KernelSpec("small_64x64_v4_k32", 64, 64, 32, pipeline="compv4"),
    KernelSpec("small_64x64_v4_k64", 64, 64, 64, pipeline="compv4"),
    # 128x128 tiles, compv3 pipeline
    KernelSpec("med_128x128_k32", 128, 128, 32, pipeline="compv3"),
    KernelSpec("med_128x128_k64", 128, 128, 64, pipeline="compv3"),
    KernelSpec("med_128x128_k128", 128, 128, 128, pipeline="compv3"),
    # 128x128 tiles, compv4 pipeline
    KernelSpec("med_128x128_v4_k32", 128, 128, 32, pipeline="compv4"),
    KernelSpec("med_128x128_v4_k64", 128, 128, 64, pipeline="compv4"),
    KernelSpec("med_128x128_v4_k128", 128, 128, 128, pipeline="compv4"),
    # 64x128 and 128x64 tiles, compv3 pipeline
    KernelSpec("rect_64x128_k32", 64, 128, 32, pipeline="compv3"),
    KernelSpec("rect_64x128_k64", 64, 128, 64, pipeline="compv3"),
    KernelSpec("rect_128x64_k32", 128, 64, 32, pipeline="compv3"),
    KernelSpec("rect_128x64_k64", 128, 64, 64, pipeline="compv3"),
    # 64x128 and 128x64 tiles, compv4 pipeline
    KernelSpec("rect_64x128_v4_k32", 64, 128, 32, pipeline="compv4"),
    KernelSpec("rect_64x128_v4_k64", 64, 128, 64, pipeline="compv4"),
    KernelSpec("rect_128x64_v4_k32", 128, 64, 32, pipeline="compv4"),
    KernelSpec("rect_128x64_v4_k64", 128, 64, 64, pipeline="compv4"),
    # 256-wide tiles, compv3 pipeline
    KernelSpec("large_256x128_k32", 256, 128, 32, pipeline="compv3"),
    KernelSpec("large_256x128_k64", 256, 128, 64, pipeline="compv3"),
    KernelSpec("large_128x256_k32", 128, 256, 32, pipeline="compv3"),
    KernelSpec("large_128x256_k64", 128, 256, 64, pipeline="compv3"),
    KernelSpec("large_256x256_k32", 256, 256, 32, pipeline="compv3"),
    KernelSpec("large_256x256_k64", 256, 256, 64, pipeline="compv3"),
    # 256-wide tiles, compv4 pipeline
    KernelSpec("large_256x128_v4_k32", 256, 128, 32, pipeline="compv4"),
    KernelSpec("large_256x128_v4_k64", 256, 128, 64, pipeline="compv4"),
    KernelSpec("large_128x256_v4_k32", 128, 256, 32, pipeline="compv4"),
    KernelSpec("large_128x256_v4_k64", 128, 256, 64, pipeline="compv4"),
    KernelSpec("large_256x256_v4_k32", 256, 256, 32, pipeline="compv4"),
    KernelSpec("large_256x256_v4_k64", 256, 256, 64, pipeline="compv4"),
    # compv3 pipeline with the interwave scheduler
    KernelSpec("int_64x64_k32", 64, 64, 32, pipeline="compv3", scheduler="interwave"),
    KernelSpec("int_128x128_k32", 128, 128, 32, pipeline="compv3", scheduler="interwave"),
    KernelSpec("int_128x128_k64", 128, 128, 64, pipeline="compv3", scheduler="interwave"),
    KernelSpec("int_256x128_k32", 256, 128, 32, pipeline="compv3", scheduler="interwave"),
    # tile_k = 16 variants, compv3 pipeline
    KernelSpec("med_128x128_k16", 128, 128, 16, pipeline="compv3"),
    KernelSpec("rect_64x128_k16", 64, 128, 16, pipeline="compv3"),
    KernelSpec("rect_128x64_k16", 128, 64, 16, pipeline="compv3"),
    # tile_k = 16 variants, compv4 pipeline
    KernelSpec("med_128x128_v4_k16", 128, 128, 16, pipeline="compv4"),
    KernelSpec("rect_64x128_v4_k16", 64, 128, 16, pipeline="compv4"),
    KernelSpec("rect_128x64_v4_k16", 128, 64, 16, pipeline="compv4"),
    # tiles with a 32-wide edge, compv3 pipeline
    KernelSpec("rect_32x64_k32", 32, 64, 32, pipeline="compv3"),
    KernelSpec("rect_64x32_k32", 64, 32, 32, pipeline="compv3"),
    KernelSpec("rect_32x128_k32", 32, 128, 32, pipeline="compv3"),
    KernelSpec("rect_128x32_k32", 128, 32, 32, pipeline="compv3"),
    # tiles with a 32-wide edge, compv4 pipeline
    KernelSpec("rect_32x64_v4_k32", 32, 64, 32, pipeline="compv4"),
    KernelSpec("rect_64x32_v4_k32", 64, 32, 32, pipeline="compv4"),
    KernelSpec("rect_32x128_v4_k32", 32, 128, 32, pipeline="compv4"),
    KernelSpec("rect_128x32_v4_k32", 128, 32, 32, pipeline="compv4"),
]
|
|
|
|
|
|
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Build the dispatcher ``KernelConfig`` that corresponds to ``spec``.

    Args:
        spec: Tile shape, pipeline and scheduler of the kernel to describe.
        dtype: Element type used for the A, B and C operands; the
            accumulator type is always "fp32".
        arch: Architecture string forwarded as ``gfx_arch``.

    Returns:
        A ``KernelConfig`` with row/col/row layouts for A/B/C, a 2x2x1 wave
        layout, warp_k of 16 and the "cshuffle" epilogue.
    """
    # Tiles with tile_m of at most 64 use 16x16 warp tiles; larger ones 32x32.
    # NOTE(review): only tile_m is consulted, so a 128x32 tile is paired with
    # warp_n=32 and wave_n=2 -- confirm the dispatcher accepts or
    # auto-corrects that combination.
    warp_edge = 16 if spec.tile_m <= 64 else 32

    operand_types = dict.fromkeys(("dtype_a", "dtype_b", "dtype_c"), dtype)
    layouts = {"layout_a": "row", "layout_b": "col", "layout_c": "row"}
    tile_shape = {
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
    }
    wave_layout = {"wave_m": 2, "wave_n": 2, "wave_k": 1}
    warp_tile = {"warp_m": warp_edge, "warp_n": warp_edge, "warp_k": 16}

    return KernelConfig(
        **operand_types,
        dtype_acc="fp32",
        **layouts,
        **tile_shape,
        **wave_layout,
        **warp_tile,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
|
|
|
|
|
def print_kernel_table(specs: List[KernelSpec], dtype: str):
    """Print a formatted table of kernel configurations.

    Args:
        specs: Kernel specifications to list; each becomes one numbered row
            (numbering starts at 1).
        dtype: Data type label printed beneath the table.

    The "Name" column is at least 18 characters wide and grows to fit the
    longest name in ``specs``, so the Tile/Pipeline/Scheduler columns stay
    aligned even for names longer than 18 characters.
    """
    # A fixed width of 18 is too narrow for names such as
    # "large_256x128_v4_k64" (20 characters), which would shift every column
    # after the name; widen the column (and the rules) by the overflow.
    name_width = max([18, *(len(spec.name) for spec in specs)])
    rule = " " + "-" * (68 + name_width - 18)

    print("\n" + "=" * 70)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print("=" * 70)
    print(
        f"\n {'#':<3} {'Name':<{name_width}} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}"
    )
    print(rule)

    for i, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(
            f" {i:<3} {spec.name:<{name_width}} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}"
        )

    print(rule)
    print(f" Data type: {dtype}")
|
|
|
|
|
|
def main():
    """Run each selected kernel on one square GEMM problem and report results.

    Parses the command line, prints the table of selected kernel specs, then
    for every spec: generates seeded random inputs, sets up a dispatcher for
    that single kernel, runs the GEMM, and compares the output against a
    NumPy reference.  A summary with the pass count and the highest-TFLOPS
    passing kernel is printed at the end.

    Returns:
        0 if every selected kernel ran and matched the reference, else 1.
        Kernels reported as SKIP (problem size not supported) count as not
        passed.
    """
    parser = argparse.ArgumentParser(
        description="Basic GEMM Example with Multiple Kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 01_basic_gemm.py                    # Default FP16 with all kernels
    python3 01_basic_gemm.py --dtype bf16       # BF16 mode
    python3 01_basic_gemm.py --size 2048        # Larger problem size
    python3 01_basic_gemm.py --num-kernels 2    # Test only 2 kernels
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    args = parser.parse_args()

    # A non-positive size would produce empty matrices and make the np.max()
    # reduction below raise; reject it with a normal usage error instead.
    if args.size <= 0:
        parser.error("--size must be a positive integer")

    reset_for_example()

    print("=" * 70)
    print("Example 01: Basic GEMM with Multiple Kernels")
    print("=" * 70)

    # Select kernels to test (any value <= 0 selects the whole list)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # =========================================================================
    # Step 1: Print all kernel configurations
    # =========================================================================
    print_kernel_table(specs, args.dtype)

    # =========================================================================
    # Step 2: Setup and test each kernel
    # =========================================================================
    print("\n" + "=" * 70)
    print(" RUNNING KERNELS")
    print("=" * 70)

    # Host data for bf16 is generated as float16, same as for fp16.
    # NOTE(review): presumably the dispatcher converts to bf16 internally --
    # confirm against ctypes_utils.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    M, N, K = args.size, args.size, args.size

    # Each entry: (name, passed, time_ms, tflops, max_err)
    results = []

    # Widen the Name column beyond 18 when a selected name is longer, so the
    # numeric columns stay aligned.
    name_width = max([18, *(len(spec.name) for spec in specs)])
    not_run = f"{'N/A':>10} {'N/A':>10} {'N/A':>10}"

    print(f"\n Problem size: {M}x{N}x{K}\n")
    print(
        f" {'#':<3} {'Name':<{name_width}} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
    )
    print(" " + "-" * (78 + name_width - 18))

    for i, spec in enumerate(specs, 1):
        # Create unique, reproducible test data per kernel
        np.random.seed(42 + i * 1000)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        config = create_kernel_config(spec, args.dtype, args.arch)
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        row = f" {i:<3} {spec.name:<{name_width}} {tile:<14}"

        # cleanup_gemm() must run once per kernel on every path, including
        # when setup, run or validation raises, hence the try/finally.
        try:
            setup = setup_gemm_dispatcher(
                config=config,
                registry_name=f"kernel_{spec.name}",
                verbose=False,
                auto_rebuild=True,
            )

            if not setup.success:
                print(f"{row} {not_run} {'FAIL':<8}")
                results.append((spec.name, False, 0, 0, 0))
                continue

            dispatcher = setup.dispatcher

            # Check if size is supported
            # NOTE(review): a SKIP is recorded as not passed, so it is counted
            # in the "KERNELS FAILED" total and makes the exit code non-zero;
            # confirm this is intended.
            if not dispatcher.is_supported(M, N, K):
                print(f"{row} {not_run} {'SKIP':<8}")
                results.append((spec.name, False, 0, 0, 0))
                continue

            # Run GEMM
            result = dispatcher.run(A, B, M, N, K)

            if not result.success:
                print(f"{row} {not_run} {'FAIL':<8}")
                results.append((spec.name, False, 0, 0, 0))
                continue

            # Validate against a NumPy reference computed in float32 and
            # rounded back to the host dtype
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = float(np.max(np.abs(result.output - C_ref)))

            # Check if within tolerance (a NaN error compares False -> FAIL)
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"

            print(
                f"{row} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
            )
            results.append((spec.name, passed, result.time_ms, result.tflops, max_err))
        finally:
            cleanup_gemm()

    # =========================================================================
    # Step 3: Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)

    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed

    print(f"\n Results: {passed}/{len(results)} kernels passed")
    print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")

    # Only kernels that passed validation compete for "best kernel"
    valid_results = [r for r in results if r[1]]
    if valid_results:
        best = max(valid_results, key=lambda x: x[3])
        print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")

    if failed == 0:
        print("\n *** ALL KERNELS PASSED ***")
    else:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print("=" * 70)

    return 0 if failed == 0 else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value (0 = all kernels passed, 1 otherwise) as the
    # process exit status.
    raise SystemExit(main())
|