mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
This commit is contained in:
committed by
GitHub
parent
44f481a45c
commit
9e049a32a1
331
dispatcher/examples/gemm/python/01_basic_gemm.py
Normal file
331
dispatcher/examples/gemm/python/01_basic_gemm.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 01: Basic GEMM with Multiple Kernels
|
||||
|
||||
Demonstrates:
|
||||
1. Declaring multiple kernel configurations
|
||||
2. Printing all registered kernels
|
||||
3. Running each kernel and validating output
|
||||
4. Comparing performance across kernels
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 01_basic_gemm.py
|
||||
python3 01_basic_gemm.py --help
|
||||
python3 01_basic_gemm.py --dtype bf16
|
||||
python3 01_basic_gemm.py --size 2048
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Specification for a kernel configuration"""

    # Human-readable identifier; also used to build the per-kernel
    # dispatcher registry name ("kernel_<name>") in main().
    name: str
    # Block-tile sizes along the GEMM M, N and K dimensions.
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline variant; KERNEL_SPECS uses "compv3" and "compv4".
    pipeline: str = "compv3"
    # Wave scheduling strategy; KERNEL_SPECS uses "intrawave" and "interwave".
    scheduler: str = "intrawave"
|
||||
|
||||
|
||||
# Define multiple kernel configurations to test (48 kernels).
# Grouped by macro-tile size (small/medium/rectangular/large), pipeline
# (compv3 vs compv4) and scheduler (intrawave default vs interwave).
KERNEL_SPECS = [
    # Small tiles - compv3
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
    # Small tiles - compv4
    KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
    KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
    # Medium tiles - compv3
    KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
    KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
    KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
    # Medium tiles - compv4
    KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
    KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
    KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
    # Rectangular tiles - compv3
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
    # Rectangular tiles - compv4
    KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
    KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
    KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
    KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
    # Large tiles - compv3
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
    # Large tiles - compv4
    KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
    KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
    KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
    KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
    KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
    KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
    # Interwave scheduler variants
    KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
    # More tile_k variations - compv3
    KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
    KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
    KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
    # More tile_k variations - compv4
    KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
    KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
    KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
    # Additional rectangular
    KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
    KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
    KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
    KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
    # Additional compv4 variants
    KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
    KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
    KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
    KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
|
||||
|
||||
|
||||
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Build a complete KernelConfig from *spec* for the given dtype and arch.

    All three operand dtypes are set to *dtype* with fp32 accumulation, and
    the layout is fixed to row-major A/C with column-major B.
    """
    # Small macro tiles (M <= 64) use 16x16 warp tiles; larger ones use 32x32.
    warp_dim = 16 if spec.tile_m <= 64 else 32

    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=warp_dim,
        warp_n=warp_dim,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
def print_kernel_table(specs: List[KernelSpec], dtype: str):
    """Print a formatted table of kernel configurations"""
    banner = "=" * 70
    divider = " " + "-" * 68

    print("\n" + banner)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(banner)
    print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
    print(divider)

    # One row per spec, 1-based numbering to match the run tables below.
    for idx, entry in enumerate(specs, 1):
        shape = f"{entry.tile_m}x{entry.tile_n}x{entry.tile_k}"
        print(
            f" {idx:<3} {entry.name:<18} {shape:<14} {entry.pipeline:<10} {entry.scheduler:<12}"
        )

    print(divider)
    print(f" Data type: {dtype}")
|
||||
|
||||
|
||||
def main():
    """Run every selected kernel spec on one GEMM problem and report results.

    Returns 0 when all executed kernels validate against the NumPy reference,
    1 otherwise (used as the process exit code).
    """
    parser = argparse.ArgumentParser(
        description="Basic GEMM Example with Multiple Kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 01_basic_gemm.py                  # Default FP16 with 4 kernels
  python3 01_basic_gemm.py --dtype bf16     # BF16 mode
  python3 01_basic_gemm.py --size 2048      # Larger problem size
  python3 01_basic_gemm.py --num-kernels 2  # Test only 2 kernels
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 70)
    print("Example 01: Basic GEMM with Multiple Kernels")
    print("=" * 70)

    # Select kernels to test (0 or negative --num-kernels keeps the full list)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # =========================================================================
    # Step 1: Print all kernel configurations
    # =========================================================================
    print_kernel_table(specs, args.dtype)

    # =========================================================================
    # Step 2: Setup and test each kernel
    # =========================================================================
    print("\n" + "=" * 70)
    print(" RUNNING KERNELS")
    print("=" * 70)

    # bf16 has no native NumPy dtype; host buffers are staged as float16 —
    # presumably converted downstream by the dispatcher (TODO confirm).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    # Square problem: all three GEMM dimensions use --size.
    M, N, K = args.size, args.size, args.size

    # Each row: (name, passed, time_ms, tflops, max_err)
    results = []

    print(f"\n Problem size: {M}x{N}x{K}\n")
    print(
        f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)

    for i, spec in enumerate(specs, 1):
        # Create unique test data per kernel (seed varies with the index)
        np.random.seed(42 + i * 1000)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        # Create config and setup dispatcher
        config = create_kernel_config(spec, args.dtype, args.arch)

        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"kernel_{spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"

        # Setup failure: record a zeroed row and move on to the next spec.
        if not setup.success:
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher

        # Check if size is supported by this kernel configuration
        if not dispatcher.is_supported(M, N, K):
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        # Run GEMM
        result = dispatcher.run(A, B, M, N, K)

        if not result.success:
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        # Validate against NumPy reference (accumulate in fp32, cast back)
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))

        # Check if within tolerance (fixed absolute threshold)
        passed = max_err < 1e-2
        status = "PASS" if passed else "FAIL"

        print(
            f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
        )
        results.append((spec.name, passed, result.time_ms, result.tflops, max_err))

        # Tear down per-kernel state before the next iteration.
        cleanup_gemm()

    # =========================================================================
    # Step 3: Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)

    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed

    print(f"\n Results: {passed}/{len(results)} kernels passed")
    print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")

    if results:
        # Best kernel among those that passed, ranked by TFLOPS (index 3).
        valid_results = [r for r in results if r[1]]
        if valid_results:
            best = max(valid_results, key=lambda x: x[3])
            print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")

    if failed == 0:
        print("\n *** ALL KERNELS PASSED ***")
    else:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print("=" * 70)

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
149
dispatcher/examples/gemm/python/02_batch_gemm.py
Normal file
149
dispatcher/examples/gemm/python/02_batch_gemm.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 02: Batch GEMM
|
||||
|
||||
Runs multiple GEMM operations with different sizes.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 02_batch_gemm.py
|
||||
python3 02_batch_gemm.py --help
|
||||
python3 02_batch_gemm.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Run one fixed 128x128x32-tile dispatcher over a sequence of sizes.

    Despite the name, the sizes are executed sequentially (one dispatcher.run
    per size), not as a single batched GEMM call. Returns 0 on completion,
    1 if the dispatcher fails to set up.
    """
    parser = argparse.ArgumentParser(
        description="Batch GEMM Example - runs multiple sizes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 02_batch_gemm.py                  # Default FP16
  python3 02_batch_gemm.py --dtype bf16     # BF16 GEMM
  python3 02_batch_gemm.py --max-size 2048  # Limit max size
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=4096,
        help="Maximum problem size (default: 4096)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 02: Batch GEMM")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Run batch of different sizes
    # =========================================================================
    print("\nStep 2: Run Batch")

    # Generate sizes up to max_size (square problems, powers of two)
    all_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]
    sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size]

    # bf16 has no native NumPy dtype; host buffers are staged as float16 —
    # presumably converted downstream by the dispatcher (TODO confirm).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}")
    print(" " + "-" * 60)

    # Aggregates for the average-throughput summary below.
    total_ops = 0
    total_time = 0

    for M, N, K in sizes:
        if not dispatcher.is_supported(M, N, K):
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped")
            continue

        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            # A GEMM performs 2*M*N*K floating-point operations.
            total_ops += 2 * M * N * K
            total_time += result.time_ms
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK"
            )
        else:
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error")

    print(" " + "-" * 60)

    if total_time > 0:
        # Aggregate throughput: total TFLOPs over total wall time in seconds.
        avg_tflops = (total_ops / 1e12) / (total_time / 1000)
        print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")

    # Cleanup
    cleanup_gemm()

    print("\n" + "=" * 60)
    print("Batch GEMM complete!")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 03: Benchmark
|
||||
|
||||
Performance benchmarking with compute-optimized kernel configuration.
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 03_benchmark.py
|
||||
python3 03_benchmark.py --help
|
||||
python3 03_benchmark.py --size 4096
|
||||
python3 03_benchmark.py --dtype bf16 --iterations 20
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Benchmark a compute-optimized (compv4/intrawave) GEMM configuration.

    Runs warmup iterations followed by timed iterations per size and prints
    min/avg latency plus derived TFLOPS. Returns 0 on completion, 1 if the
    dispatcher fails to set up.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Benchmark Example - performance testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 03_benchmark.py                  # Default benchmark suite
  python3 03_benchmark.py --size 4096      # Single size benchmark
  python3 03_benchmark.py --dtype bf16     # BF16 benchmark
  python3 03_benchmark.py --iterations 20  # More iterations
""",
    )
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: bf16)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="Single problem size MxNxK (default: run all sizes)",
    )
    parser.add_argument(
        "--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
    )
    parser.add_argument(
        "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 03: Benchmark")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher with compute-optimized config
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        pipeline="compv4",
        scheduler="intrawave",
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Benchmark
    # =========================================================================
    print("\nStep 2: Benchmark")

    # --size > 0 benchmarks one square problem; otherwise a fixed suite
    # of square and non-square shapes.
    if args.size > 0:
        sizes = [(args.size, args.size, args.size)]
    else:
        sizes = [
            (512, 512, 512),
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (1024, 2048, 512),
            (2048, 1024, 2048),
        ]

    # bf16 has no native NumPy dtype; host buffers are staged as float16 —
    # presumably converted downstream by the dispatcher (TODO confirm).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n")

    print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}")
    print(" " + "-" * 60)

    all_tflops = []

    for M, N, K in sizes:
        # Unsupported sizes are silently skipped (no table row).
        if not dispatcher.is_supported(M, N, K):
            continue

        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        # Warmup (results discarded)
        for _ in range(args.warmup):
            dispatcher.run(A, B, M, N, K)

        # Benchmark: only successful runs contribute timings.
        times = []
        for _ in range(args.iterations):
            result = dispatcher.run(A, B, M, N, K)
            if result.success:
                times.append(result.time_ms)

        if times:
            min_time = min(times)
            avg_time = sum(times) / len(times)
            # 2*M*N*K FLOPs per GEMM over the average time in seconds.
            tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12
            all_tflops.append(tflops)
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
            )

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)

    if all_tflops:
        print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
        print(f" Peak: {max(all_tflops):.2f} TFLOPS")

    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
156
dispatcher/examples/gemm/python/04_validation.py
Normal file
156
dispatcher/examples/gemm/python/04_validation.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 04: Validation
|
||||
|
||||
Validates GPU GEMM against NumPy reference.
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 04_validation.py
|
||||
python3 04_validation.py --help
|
||||
python3 04_validation.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Validator,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Validate GPU GEMM output against a NumPy fp32 reference.

    Runs a fixed suite of test cases (identity and random patterns) through
    a 128x128x32-tile dispatcher and checks each result with the Validator's
    rtol/atol criteria. Returns 0 when every executed case passes, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Validation Example - validates GPU results against NumPy",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 04_validation.py               # Default FP16 validation
  python3 04_validation.py --dtype bf16  # BF16 validation
  python3 04_validation.py --rtol 1e-2   # Relaxed tolerance
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 04: Validation")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Run validation tests
    # =========================================================================
    print("\nStep 2: Validation Tests")

    validator = Validator(rtol=args.rtol, atol=args.atol)
    # bf16 has no native NumPy dtype; host buffers are staged as float16 —
    # presumably converted downstream by the dispatcher (TODO confirm).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # (label, M, N, K, input pattern)
    test_cases = [
        ("Identity", 128, 128, 128, "identity"),
        ("Small", 256, 256, 256, "random"),
        ("Medium", 512, 512, 512, "random"),
        ("Large", 1024, 1024, 1024, "random"),
        ("Non-square", 512, 1024, 256, "random"),
    ]

    passed = 0
    failed = 0

    print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}")
    print(" " + "-" * 55)

    for name, M, N, K, pattern in test_cases:
        # Skipped cases count toward neither passed nor failed.
        if not dispatcher.is_supported(M, N, K):
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped")
            continue

        # Fixed seed so each case's random inputs are reproducible.
        np.random.seed(42)
        if pattern == "identity":
            # Rectangular identity matrices: exact, easily checkable product.
            A = np.eye(M, K, dtype=np_dtype)
            B = np.eye(K, N, dtype=np_dtype)
        else:
            A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
            B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED")
            failed += 1
            continue

        # Reference computed in fp32, then cast back to the test dtype.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        is_valid, max_err, _ = validator.check(result.output, C_ref)

        if is_valid:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED")
            passed += 1
        else:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED")
            failed += 1

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    total = passed + failed
    print(f"Results: {passed}/{total} passed")
    print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}")
    print("=" * 60)

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
166
dispatcher/examples/gemm/python/05_numpy_integration.py
Normal file
166
dispatcher/examples/gemm/python/05_numpy_integration.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 05: NumPy Integration
|
||||
|
||||
Shows how to create a GPU-accelerated matmul wrapper.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 05_numpy_integration.py
|
||||
python3 05_numpy_integration.py --help
|
||||
python3 05_numpy_integration.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Dispatcher,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
class GPUMatmul:
    """GPU-accelerated matrix multiplication wrapper."""

    def __init__(self, dispatcher: Dispatcher):
        # Dispatcher that owns the compiled GEMM kernels.
        self.dispatcher = dispatcher

    def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """Compute C = A @ B on GPU with CPU fallback."""
        M, K = A.shape
        K2, N = B.shape
        if K != K2:
            raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}")

        if self.dispatcher.is_supported(M, N, K):
            result = self.dispatcher.run(A, B, M, N, K)
            if result.success:
                return result.output

        # Fallback path: size unsupported by the dispatcher, or GPU run failed.
        return np.matmul(A, B)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: demonstrate wrapping the dispatcher as a NumPy matmul.

    Returns 0 on success, 1 if the dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated matmul wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 05_numpy_integration.py              # Default FP16
  python3 05_numpy_integration.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    # Reset any global dispatcher state left over from a previous example run.
    reset_for_example()

    print("=" * 60)
    print("Example 05: NumPy Integration")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher
    # NOTE(review): bf16 host buffers are represented as np.float16 here —
    # presumably the library converts on upload; confirm against ctypes_utils.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # =========================================================================
    # Step 2: Create GPU matmul wrapper
    # =========================================================================
    print("\nStep 2: Create GPUMatmul")

    gpu_matmul = GPUMatmul(dispatcher=dispatcher)
    print(" gpu_matmul ready")

    # =========================================================================
    # Step 3: Demo - Simple multiplication using gpu_matmul
    # =========================================================================
    print("\nStep 3: Demo - Simple Multiplication")

    # Small scale factor keeps accumulated values inside fp16 range.
    A = np.random.randn(1024, 512).astype(np_dtype) * 0.1
    B = np.random.randn(512, 256).astype(np_dtype) * 0.1

    # Use the gpu_matmul wrapper
    C = gpu_matmul(A, B)
    print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}")

    # Run the same product through the dispatcher directly to report timing.
    M, K = A.shape
    _, N = B.shape
    result = dispatcher.run(A, B, M, N, K)

    print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}")
    print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")

    # =========================================================================
    # Step 4: Demo - FFN block (two chained GEMMs: X @ W1 then H @ W2)
    # =========================================================================
    print("\nStep 4: Demo - FFN Block")

    batch, hidden, ffn = 128, 768, 3072
    X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02
    W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02
    W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02

    result1 = dispatcher.run(X, W1, batch, ffn, hidden)
    H = result1.output
    result2 = dispatcher.run(H, W2, batch, hidden, ffn)

    print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}")
    print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms")

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("NumPy Integration Pattern:")
    print("=" * 60)
    print(" 1. setup_gemm_dispatcher(config)")
    print(" 2. GPUMatmul(dispatcher)")
    print(" 3. C = gpu_matmul(A, B)")
    print("=" * 60)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
||||
169
dispatcher/examples/gemm/python/06_json_export.py
Normal file
169
dispatcher/examples/gemm/python/06_json_export.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 06: JSON Export
|
||||
|
||||
Exports registry configuration to JSON.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 06_json_export.py
|
||||
python3 06_json_export.py --help
|
||||
python3 06_json_export.py --output my_kernels.json
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: export kernel registry information to a JSON file.

    Returns 0 on success, 1 if the dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="JSON Export Example - exports registry to JSON",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 06_json_export.py                  # Default output to kernels.json
  python3 06_json_export.py --output my.json # Custom output file
""",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="kernels.json",
        help="Output JSON file (default: kernels.json)",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    # Reset any global dispatcher state left over from a previous example run.
    reset_for_example()

    print("=" * 60)
    print("Example 06: JSON Export")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    # =========================================================================
    # Step 2: Define additional configs for export (not compiled, only listed)
    # =========================================================================
    print("\nStep 2: Define Additional Configs")

    configs = [
        config,
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=256,
            tile_n=256,
            tile_k=64,
            gfx_arch=args.arch,
        ),
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=64,
            tile_n=64,
            tile_k=32,
            gfx_arch=args.arch,
        ),
    ]

    for cfg in configs:
        print(f" - {cfg.tile_str}")

    # =========================================================================
    # Step 3: Export to JSON
    # =========================================================================
    print("\nStep 3: Export to JSON")

    export_data = {
        "registry": setup.registry.name,
        "kernel_count": len(configs),
        "kernels": [],
    }

    # One summary record per config: tile shape, dtypes, layout, target arch.
    for cfg in configs:
        kernel_info = {
            "tile": cfg.tile_str,
            "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c},
            "layout": cfg.layout,
            "pipeline": cfg.pipeline,
            "target": cfg.gfx_arch,
        }
        export_data["kernels"].append(kernel_info)

    # Include C++ library info (best-effort: skipped if the library returns
    # something that is not valid JSON).
    if setup.lib:
        cpp_json = setup.lib.export_registry_json()
        try:
            export_data["cpp_registry"] = json.loads(cpp_json)
        except json.JSONDecodeError:
            pass

    json_str = json.dumps(export_data, indent=2)

    with open(args.output, "w") as f:
        f.write(json_str)
    print(f" Saved to: {args.output}")

    # Preview (first 500 characters of the emitted JSON)
    print("\nStep 4: Preview")
    print("-" * 60)
    print(json_str[:500] + ("..." if len(json_str) > 500 else ""))
    print("-" * 60)

    # Cleanup
    cleanup_gemm()

    print("\n" + "=" * 60)
    print("JSON Export complete!")
    print("=" * 60)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
||||
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
@@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 07: Stress Test - Multiple Kernels with Validation
|
||||
|
||||
Consolidated stress test that:
|
||||
1. Declares multiple kernel configurations (various tiles, pipelines, layouts)
|
||||
2. Prints all registered kernels with details
|
||||
3. Validates each kernel against NumPy reference
|
||||
4. Optional benchmarking mode
|
||||
|
||||
This tests:
|
||||
- Multiple tile sizes (64x64, 128x128, 256x256)
|
||||
- Multiple pipelines (compv3, compv4)
|
||||
- Multiple data types (fp16, bf16)
|
||||
- Different schedulers (intrawave, interwave)
|
||||
|
||||
Complexity: ★★★★☆
|
||||
|
||||
Usage:
|
||||
python3 07_stress_test.py
|
||||
python3 07_stress_test.py --help
|
||||
python3 07_stress_test.py --num-kernels 10
|
||||
python3 07_stress_test.py --benchmark
|
||||
python3 07_stress_test.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
Validator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """A kernel specification for testing.

    Bundles the block-tile / wave / warp-tile shape together with pipeline,
    scheduler and layout choices, and converts itself into a full
    KernelConfig via :meth:`to_config`.
    """

    name: str  # label used in report tables and registry names
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    layout: str = "rcr"  # one char per matrix (A, B, C): 'r' = row, 'c' = col

    def to_config(self, dtype: str, arch: str) -> KernelConfig:
        """Convert to KernelConfig, using *dtype* for A/B/C and *arch* as target."""
        # Adjust warp tiles for smaller tiles: a warp tile is capped at the
        # per-wave share of the block tile (tile / wave count).
        warp_m = min(self.warp_m, self.tile_m // self.wave_m)
        warp_n = min(self.warp_n, self.tile_n // self.wave_n)
        warp_k = self.warp_k

        return KernelConfig(
            dtype_a=dtype,
            dtype_b=dtype,
            dtype_c=dtype,
            dtype_acc="fp32",
            layout_a={"r": "row", "c": "col"}[self.layout[0]],
            layout_b={"r": "row", "c": "col"}[self.layout[1]],
            layout_c={"r": "row", "c": "col"}[self.layout[2]],
            tile_m=self.tile_m,
            tile_n=self.tile_n,
            tile_k=self.tile_k,
            wave_m=self.wave_m,
            wave_n=self.wave_n,
            wave_k=self.wave_k,
            warp_m=warp_m,
            warp_n=warp_n,
            warp_k=warp_k,
            pipeline=self.pipeline,
            scheduler=self.scheduler,
            epilogue="cshuffle",
            gfx_arch=arch,
        )
|
||||
|
||||
|
||||
# Stress-test kernel configurations.  Each entry varies one axis (tile
# shape, pipeline, or scheduler) relative to its neighbours so that a
# failure localizes quickly to a single configuration choice.
KERNEL_SPECS = [
    # Small tiles - compv3 / compv4
    KernelSpec("small_compv3", 64, 64, 32, wave_m=2, wave_n=2, warp_m=16, warp_n=16, warp_k=32, pipeline="compv3"),
    KernelSpec("small_compv4", 64, 64, 32, wave_m=2, wave_n=2, warp_m=16, warp_n=16, warp_k=32, pipeline="compv4"),
    # Medium tiles
    KernelSpec("medium_compv3", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("medium_compv4", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
    KernelSpec("medium_k64", 128, 128, 64, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Rectangular tiles
    KernelSpec("rect_64x128", 64, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("rect_128x64", 128, 64, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Different schedulers
    KernelSpec("interwave", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3", scheduler="interwave"),
    # Large tiles
    KernelSpec("large_compv3", 256, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("large_compv4", 256, 128, 64, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
]
|
||||
|
||||
|
||||
def print_kernel_summary(specs: List[KernelSpec], dtype: str):
    """Render a table describing every kernel spec under test."""
    divider = " " + "-" * 78

    print("\n" + "=" * 80)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print("=" * 80)
    print(
        f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}"
    )
    print(divider)

    for idx, entry in enumerate(specs, 1):
        # Format each MxNxK triple as a compact "MxNxK" string.
        tile = f"{entry.tile_m}x{entry.tile_n}x{entry.tile_k}"
        wave = f"{entry.wave_m}x{entry.wave_n}x{entry.wave_k}"
        warp = f"{entry.warp_m}x{entry.warp_n}x{entry.warp_k}"
        print(
            f" {idx:<3} {entry.name:<18} {tile:<12} {wave:<10} {warp:<12} {entry.pipeline:<10} {entry.scheduler:<10}"
        )

    print(divider)
    print(f" Data type: {dtype}\n")
|
||||
|
||||
|
||||
def validate_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    validator: Validator,
    kernel_index: int = 0,
    verbose: bool = False,  # NOTE(review): currently unused in this function
) -> Tuple[bool, float, str]:
    """
    Validate a single kernel configuration against a NumPy reference.

    Builds a dispatcher for *spec*, runs one size x size x size GEMM on the
    GPU, and compares the result to a float32 NumPy matmul via *validator*.
    The dispatcher is torn down (cleanup_gemm) on every exit path after setup.

    Returns: (passed, max_error, message)
    """
    # bf16 host data is represented as np.float16 here — presumably converted
    # by the library on upload; verify against ctypes_utils.
    np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    # Create config
    config = spec.to_config(dtype, arch)

    # Setup dispatcher (auto_rebuild compiles the kernel if needed)
    setup = setup_gemm_dispatcher(
        config=config,
        registry_name=f"stress_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )

    if not setup.success:
        return False, 0.0, f"Setup failed: {setup.error}"

    dispatcher = setup.dispatcher
    M, N, K = size, size, size

    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, f"Size {M}x{N}x{K} not supported"

    # Use different seed per kernel to get unique test data
    # This ensures each kernel is tested with different matrices
    np.random.seed(42 + kernel_index * 1000)
    A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
    B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

    # Run GPU GEMM
    result = dispatcher.run(A, B, M, N, K)

    if not result.success:
        cleanup_gemm()
        return False, 0.0, "GPU execution failed"

    # Validate against NumPy (accumulate in float32, then cast back to the
    # storage dtype so rounding matches the GPU output representation)
    C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
    is_valid, max_err, _ = validator.check(result.output, C_ref)

    cleanup_gemm()

    return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS"
|
||||
|
||||
|
||||
def benchmark_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    warmup: int = 3,
    iterations: int = 10,
) -> Tuple[bool, float, float]:
    """
    Benchmark a kernel configuration on a size x size x size GEMM.

    Runs *warmup* untimed iterations, then averages the dispatcher-reported
    time over up to *iterations* timed runs (failed runs are excluded).
    TFLOPS uses the standard 2*M*N*K flop count for GEMM.

    Returns: (success, avg_time_ms, tflops)
    """
    np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    config = spec.to_config(dtype, arch)
    setup = setup_gemm_dispatcher(
        config=config,
        registry_name=f"bench_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )

    if not setup.success:
        return False, 0.0, 0.0

    dispatcher = setup.dispatcher
    M, N, K = size, size, size

    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, 0.0

    A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
    B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

    # Warmup (untimed; amortizes first-launch overheads)
    for _ in range(warmup):
        dispatcher.run(A, B, M, N, K)

    # Benchmark: keep only timings from successful runs
    times = []
    for _ in range(iterations):
        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            times.append(result.time_ms)

    cleanup_gemm()

    if not times:
        return False, 0.0, 0.0

    avg_time = sum(times) / len(times)
    # 2*M*N*K flops per GEMM; time is in milliseconds.
    tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12

    return True, avg_time, tflops
|
||||
|
||||
|
||||
def main():
    """CLI entry point: validate (and optionally benchmark) all kernel specs.

    Returns 0 when every tested kernel passes validation, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Stress Test - Multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 07_stress_test.py                    # Test all kernels
  python3 07_stress_test.py --num-kernels 5    # Test first 5 kernels
  python3 07_stress_test.py --benchmark        # Include benchmarks
  python3 07_stress_test.py --dtype bf16       # Test BF16
  python3 07_stress_test.py --size 2048        # Use 2048x2048 matrices
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Include benchmark timing",
    )
    parser.add_argument(
        "--rtol",
        type=float,
        default=1e-2,
        help="Relative tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-2,
        help="Absolute tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()

    # Reset any global dispatcher state left over from a previous example run.
    reset_for_example()

    print("=" * 80)
    print("Example 07: GEMM Stress Test - Multiple Kernels")
    print("=" * 80)

    # Select kernels to test (a prefix of KERNEL_SPECS, or all of them)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # Print kernel summary
    print_kernel_summary(specs, args.dtype)

    # Run validation
    print("\n" + "=" * 80)
    print(" VALIDATION RESULTS")
    print("=" * 80)

    validator = Validator(rtol=args.rtol, atol=args.atol)

    # Table header differs depending on whether timings are included.
    if args.benchmark:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}"
        )
    else:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}"
        )
    print(" " + "-" * 78)

    passed = 0
    failed = 0
    skipped = 0

    for i, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"

        try:
            # kernel_index=i gives each kernel its own random seed / test data.
            is_valid, max_err, info = validate_kernel(
                spec, args.dtype, args.arch, args.size, validator, kernel_index=i
            )

            if is_valid:
                status = "PASS"
                passed += 1
            else:
                status = "FAIL"
                failed += 1

            if args.benchmark:
                success, avg_time, tflops = benchmark_kernel(
                    spec, args.dtype, args.arch, args.size
                )
                if success:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}"
                    )
                else:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}"
                    )
            else:
                print(
                    f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}"
                )

        except Exception as e:
            # A crash in one kernel (setup/compile/run) must not abort the
            # whole stress test — record it as skipped and continue.
            skipped += 1
            print(
                f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}"
            )

    # Summary
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)
    total = passed + failed + skipped
    print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped")
    print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
    print(f" Architecture: {args.arch}")

    if failed == 0 and skipped == 0:
        print("\n *** ALL KERNELS PASSED ***")
    elif failed > 0:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print("=" * 80)

    # Skipped kernels do not fail the run; only explicit FAILs do.
    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
||||
718
dispatcher/examples/gemm/python/08_heuristics.py
Normal file
718
dispatcher/examples/gemm/python/08_heuristics.py
Normal file
@@ -0,0 +1,718 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 08: Custom Heuristics
|
||||
|
||||
Demonstrates custom kernel selection heuristics based on problem characteristics.
|
||||
|
||||
This example shows how to:
|
||||
1. Define multiple kernel configurations for different workloads
|
||||
2. Implement custom heuristics to select the best kernel
|
||||
3. Test heuristic selection across different problem sizes
|
||||
|
||||
Heuristic strategies:
|
||||
- Size-based: Small tiles for small problems, large tiles for large problems
|
||||
- Compute-bound: Maximize compute utilization for large matrices
|
||||
- Memory-bound: Optimize memory access for bandwidth-limited cases
|
||||
- Latency-focused: Minimize kernel launch overhead for small problems
|
||||
|
||||
Complexity: ★★★★☆
|
||||
|
||||
Usage:
|
||||
python3 08_heuristics.py
|
||||
python3 08_heuristics.py --help
|
||||
python3 08_heuristics.py --strategy compute
|
||||
python3 08_heuristics.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kernel Specifications
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Kernel specification with metadata for heuristic selection."""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    # Metadata for heuristics
    category: str = "balanced"  # small, balanced, large, compute, memory
    # Inclusive M*N element-count window in which this kernel is a candidate.
    min_problem_size: int = 0
    # Annotated as float: the default is float("inf"), not an int.
    max_problem_size: float = float("inf")
|
||||
|
||||
|
||||
# Define kernel pool for heuristic selection (20+ kernels)
|
||||
KERNEL_POOL = [
    # --- SMALL TILES: low launch latency, candidates only up to 256*256 ---
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3", "intrawave", category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3", "intrawave", category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_v4", 64, 64, 32, "compv4", "intrawave", category="small", max_problem_size=256 * 256),
    # --- MEDIUM TILES: balanced choices for mid-sized problems ---
    KernelSpec("medium_128x128_k32", 128, 128, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128, max_problem_size=2048 * 2048),
    KernelSpec("medium_128x128_k64", 128, 128, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_k128", 128, 128, 128, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave", category="balanced", min_problem_size=256 * 256),
    # Rectangular medium tiles
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    # --- LARGE TILES: high throughput for large problems ---
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3", "intrawave", category="large", min_problem_size=1024 * 1024),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3", "intrawave", category="large", min_problem_size=1024 * 1024),
    # --- COMPUTE-OPTIMIZED: compv4 pipeline for compute-bound workloads ---
    KernelSpec("compute_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave", category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave", category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_256x128_v4", 256, 128, 64, "compv4", "intrawave", category="compute", min_problem_size=512 * 512),
    KernelSpec("compute_256x256_v4", 256, 256, 64, "compv4", "intrawave", category="compute", min_problem_size=1024 * 1024),
    # --- MEMORY-OPTIMIZED: small tile_k for bandwidth-limited cases ---
    KernelSpec("memory_128x128_k16", 128, 128, 16, "compv3", "intrawave", category="memory", min_problem_size=256 * 256),
    KernelSpec("memory_64x128_k16", 64, 128, 16, "compv3", "intrawave", category="memory", min_problem_size=128 * 128),
]
|
||||
|
||||
|
||||
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Expand a pool KernelSpec into a full KernelConfig.

    Uses row/col/row layouts, fp32 accumulation, a fixed 2x2x1 wave grid,
    and warp tiles that shrink to 16 along any block-tile edge of 64 or less.
    """
    rows_per_warp = 32 if spec.tile_m > 64 else 16
    cols_per_warp = 32 if spec.tile_n > 64 else 16

    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=rows_per_warp,
        warp_n=cols_per_warp,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Heuristic Strategies
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class HeuristicStrategy(Enum):
    """Kernel-selection strategies; values match the CLI ``--strategy`` choices."""

    SIZE_BASED = "size"          # choose by total problem size
    COMPUTE_BOUND = "compute"    # favor compute-optimized kernels
    MEMORY_BOUND = "memory"      # favor memory-friendly kernels
    LATENCY_FOCUSED = "latency"  # favor small, low-latency kernels
|
||||
|
||||
|
||||
def size_based_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel based on problem size.

    - Small problems (< 256*256 output elements): "small" category, low latency
    - Medium problems (< 1024*1024): "balanced" category
    - Large problems: "large" category, high throughput

    Within the chosen category, prefer a tile_k that divides K evenly
    (smaller remainder is better), breaking ties by larger tile area.

    Returns the best-matching KernelSpec; never returns None as long as
    *kernels* is non-empty.
    """
    total_elements = M * N

    # Filter by each kernel's declared problem-size window.
    candidates = [
        k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size
    ]
    if not candidates:
        candidates = kernels  # Fall back to the full pool

    # Map problem size to a target category.
    if total_elements < 256 * 256:
        target_category = "small"
    elif total_elements < 1024 * 1024:
        target_category = "balanced"
    else:
        target_category = "large"

    # Narrow to the target category when any kernel matches it.
    category_candidates = [k for k in candidates if k.category == target_category]
    if category_candidates:
        candidates = category_candidates

    # K % tile_k is 0 on a perfect division and otherwise the remainder,
    # so it serves directly as a "lower is better" score.
    # BUGFIX: use min() instead of list.sort() — the fallback paths above can
    # leave `candidates` aliasing the caller's `kernels` list, and sorting it
    # in place silently reordered the caller's kernel pool.
    return min(
        candidates,
        key=lambda k: (K % k.tile_k, -k.tile_m * k.tile_n),
    )
|
||||
|
||||
|
||||
def compute_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel suited to compute-bound workloads.

    Candidate pool is narrowed in preference order: "compute"-category
    kernels, then compv4-pipeline kernels, then the whole pool. Kernels
    whose min_problem_size fits the problem are preferred. Large problems
    (>= 1024*1024 output elements) take the biggest tile volume; smaller
    ones take the tile shape closest to 128x128.
    """
    workload = M * N

    # Narrow the pool in preference order.
    pool = [k for k in kernels if k.category == "compute"]
    if not pool:
        pool = [k for k in kernels if k.pipeline == "compv4"]
    if not pool:
        pool = kernels

    # Keep only kernels sized for this problem, when any qualify.
    sized = [k for k in pool if k.min_problem_size <= workload]
    if sized:
        pool = sized

    if workload >= 1024 * 1024:
        # Throughput regime: maximize tile volume.
        return max(pool, key=lambda k: k.tile_m * k.tile_n * k.tile_k)
    # Otherwise prefer the tile shape closest to the balanced 128x128.
    return min(pool, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128))
|
||||
|
||||
|
||||
def memory_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel suited to memory-bound workloads.

    Prefers "memory"-category kernels (small problems take the smallest
    tile area, larger ones the biggest), then "balanced" kernels with the
    smallest tile_k, and finally the pool-wide kernel with smallest tile_k,
    ties broken by closeness to a 128x128 tile.
    """
    memory_pool = [k for k in kernels if k.category == "memory"]
    if memory_pool:
        # Below 512x512 output elements pick the smallest tile, else the largest.
        chooser = min if M * N < 512 * 512 else max
        return chooser(memory_pool, key=lambda k: k.tile_m * k.tile_n)

    balanced_pool = [k for k in kernels if k.category == "balanced"]
    if balanced_pool:
        # Smaller tile_k gives friendlier memory access patterns.
        return min(balanced_pool, key=lambda k: k.tile_k)

    # Last resort: smallest tile_k, ties broken by distance from 128x128.
    return min(
        kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128))
    )
|
||||
|
||||
|
||||
def latency_focused_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel optimized for low latency.

    Prefers "small"-category kernels (with the compv4 pipeline when
    available), then the smallest-area compv4 kernel pool-wide, and
    finally the smallest-area kernel overall.
    """
    small_pool = [k for k in kernels if k.category == "small"]
    if small_pool:
        # First compv4 small kernel wins; otherwise the first small one.
        preferred = next((k for k in small_pool if k.pipeline == "compv4"), None)
        return preferred if preferred is not None else small_pool[0]

    # No small kernels: fall back to smallest tile area, compv4 first.
    compv4_pool = [k for k in kernels if k.pipeline == "compv4"]
    target = compv4_pool if compv4_pool else kernels
    return min(target, key=lambda k: k.tile_m * k.tile_n)
|
||||
|
||||
|
||||
# Strategy -> selection-function lookup table; main() resolves the CLI
# choice to a HeuristicStrategy and dispatches through this map.
HEURISTICS = {
    HeuristicStrategy.SIZE_BASED: size_based_heuristic,
    HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic,
    HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic,
    HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic,
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def print_kernel_pool(kernels: List[KernelSpec]):
    """Print a formatted table of all kernels available for selection."""
    banner = "=" * 75
    rule = " " + "-" * 73

    print("\n" + banner)
    print(" KERNEL POOL")
    print(banner)
    print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}")
    print(rule)

    for idx, spec in enumerate(kernels, 1):
        tile_desc = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(
            f" {idx:<3} {spec.name:<22} {tile_desc:<14} "
            f"{spec.pipeline:<10} {spec.category:<12}"
        )

    print(rule)
|
||||
|
||||
|
||||
def main():
    """Run the heuristic-selection demo.

    For each test problem size: pick a kernel via the chosen heuristic,
    build and set up a dispatcher for it, run the GEMM, verify the output
    against a NumPy float32 reference, and print a summary table.

    Returns 0 if every executed test passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Custom Heuristics Example - intelligent kernel selection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 08_heuristics.py                     # Default size-based heuristic
  python3 08_heuristics.py --strategy compute  # Compute-bound heuristic
  python3 08_heuristics.py --strategy memory   # Memory-bound heuristic
  python3 08_heuristics.py --strategy latency  # Latency-focused heuristic
  python3 08_heuristics.py --dtype bf16        # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--strategy",
        default="size",
        choices=["size", "compute", "memory", "latency"],
        help="Heuristic strategy (default: size)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 75)
    print("Example 08: Custom Heuristics")
    print("=" * 75)

    # Map the CLI strategy string to its enum, then to its selection function.
    strategy_map = {
        "size": HeuristicStrategy.SIZE_BASED,
        "compute": HeuristicStrategy.COMPUTE_BOUND,
        "memory": HeuristicStrategy.MEMORY_BOUND,
        "latency": HeuristicStrategy.LATENCY_FOCUSED,
    }
    strategy = strategy_map[args.strategy]
    heuristic_fn = HEURISTICS[strategy]

    print(f"\n Strategy: {strategy.value}")
    print(f" Data type: {args.dtype}")

    # Print kernel pool
    print_kernel_pool(KERNEL_POOL)

    # =========================================================================
    # Test heuristic selection across different problem sizes
    # =========================================================================
    print("\n" + "=" * 75)
    print(" HEURISTIC SELECTION TEST")
    print("=" * 75)

    # NumPy has no native bf16, so bf16 data is staged host-side as float16.
    # NOTE(review): confirm setup_gemm_dispatcher converts it on upload.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    test_sizes = [
        (128, 128, 64),  # Small
        (256, 256, 128),  # Small-medium
        (512, 512, 256),  # Medium
        (1024, 1024, 512),  # Medium-large
        (2048, 2048, 1024),  # Large
    ]

    print(
        f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)

    # Each entry: (size_str, kernel_name, passed, time_ms, tflops).
    results = []

    for M, N, K in test_sizes:
        # Use heuristic to select kernel
        selected_spec = heuristic_fn(M, N, K, KERNEL_POOL)

        # Create config and setup
        config = create_kernel_config(selected_spec, args.dtype, args.arch)

        # A fresh dispatcher is built (and torn down below) per problem size.
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"heuristic_{selected_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        size_str = f"{M}x{N}x{K}"

        if not setup.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher

        if not dispatcher.is_supported(M, N, K):
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        # Run GEMM with a fixed seed so inputs are reproducible per size.
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)

        if not result.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        # Validate against a float32 NumPy reference, rounded back to np_dtype.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))
        # 1e-2 absolute tolerance accommodates reduced-precision accumulation.
        passed = max_err < 1e-2

        status = "PASS" if passed else "FAIL"
        print(
            f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
        )
        results.append(
            (size_str, selected_spec.name, passed, result.time_ms, result.tflops)
        )

        cleanup_gemm()

    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)

    passed = sum(1 for r in results if r[2])
    failed = len(results) - passed

    print(f"\n Strategy: {strategy.value}")
    print(f" Results: {passed}/{len(results)} tests passed")

    # Show kernel selection distribution
    kernel_usage = {}
    for r in results:
        kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1

    print("\n Kernel Selection Distribution:")
    for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]):
        print(f" {kernel}: {count} times")

    if results:
        valid_results = [r for r in results if r[2]]
        if valid_results:
            avg_tflops = sum(r[4] for r in valid_results) / len(valid_results)
            print(f"\n Average TFLOPS: {avg_tflops:.2f}")

    if failed == 0:
        print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n *** {failed} TESTS FAILED ***")

    print("=" * 75)

    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit code.
    sys.exit(main())
|
||||
220
dispatcher/examples/gemm/python/09_multi_registry.py
Normal file
220
dispatcher/examples/gemm/python/09_multi_registry.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: Multiple Registries
|
||||
|
||||
Demonstrates multiple registries for different optimization targets.
|
||||
|
||||
Complexity: ★★★★★
|
||||
|
||||
Usage:
|
||||
python3 09_multi_registry.py
|
||||
python3 09_multi_registry.py --help
|
||||
python3 09_multi_registry.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Registry,
|
||||
Dispatcher,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Demonstrate multiple registries targeting different optimizations.

    Builds one base dispatcher, then three single-kernel registries
    (compute / memory / latency) sharing the base library handle, and
    routes each test problem to a dispatcher by output size.

    Returns 0 on success, 1 if the base dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="Multiple Registries Example - optimization-specific registries",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 09_multi_registry.py               # Default FP16
  python3 09_multi_registry.py --dtype bf16  # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 09: Multiple Registries")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup base dispatcher
    # =========================================================================
    print("\nStep 1: Setup Base Dispatcher")

    base_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    # The base setup's library handle is shared by every registry below.
    lib = setup.lib
    # NumPy has no native bf16; bf16 data is staged host-side as float16.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # =========================================================================
    # Step 2: Define configs for different optimization targets
    # =========================================================================
    print("\nStep 2: Define Optimization Targets")

    # Large tiles + compv4: throughput for big matrices.
    compute_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=256,
        tile_n=256,
        tile_k=64,
        wave_m=4,
        wave_n=4,
        pipeline="compv4",
        gfx_arch=args.arch,
    )
    # Medium tiles: cache-friendly middle ground.
    memory_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        wave_m=2,
        wave_n=2,
        pipeline="compv4",
        gfx_arch=args.arch,
    )
    # Small tiles + compv3: quick launches for small matrices.
    latency_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=64,
        tile_n=64,
        tile_k=32,
        wave_m=1,
        wave_n=1,
        pipeline="compv3",
        gfx_arch=args.arch,
    )

    print(f" Compute: {compute_config.tile_str} (large matrices)")
    print(f" Memory: {memory_config.tile_str} (medium matrices)")
    print(f" Latency: {latency_config.tile_str} (small matrices)")

    # =========================================================================
    # Step 3: Create registries (one kernel config registered per registry)
    # =========================================================================
    print("\nStep 3: Create Registries")

    compute_registry = Registry(name="compute", lib=lib)
    compute_registry.register_kernel(compute_config)

    memory_registry = Registry(name="memory", lib=lib)
    memory_registry.register_kernel(memory_config)

    latency_registry = Registry(name="latency", lib=lib)
    latency_registry.register_kernel(latency_config)

    # =========================================================================
    # Step 4: Create dispatchers
    # =========================================================================
    print("\nStep 4: Create Dispatchers")

    compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib)
    memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib)
    latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib)

    print(f" {compute_dispatcher}")
    print(f" {memory_dispatcher}")
    print(f" {latency_dispatcher}")

    # =========================================================================
    # Step 5: Smart dispatcher selection
    # =========================================================================
    print("\nStep 5: Smart Dispatcher Selection")

    def select_dispatcher(M: int, N: int, K: int) -> Dispatcher:
        # Route by output-element count: >= 4096^2 compute, >= 1024^2 memory,
        # otherwise latency.
        elements = M * N
        if elements >= 4096 * 4096:
            return compute_dispatcher
        elif elements >= 1024 * 1024:
            return memory_dispatcher
        else:
            return latency_dispatcher

    test_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]

    print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}")
    print(" " + "-" * 55)

    for M, N, K in test_sizes:
        dispatcher = select_dispatcher(M, N, K)

        # Silently skip sizes the selected dispatcher cannot handle.
        if not dispatcher.is_supported(M, N, K):
            continue

        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            print(
                f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} "
                f"{result.time_ms:>12.4f} {result.tflops:>10.2f}"
            )

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("Multi-Registry Pattern:")
    print("=" * 60)
    print(" 1. Define KernelConfig for each optimization target")
    print(" 2. Create Registry for each target")
    print(" 3. Register configs to appropriate registries")
    print(" 4. Create Dispatcher for each registry")
    print(" 5. Select dispatcher based on problem characteristics")
    print(" 6. Run GEMM with selected dispatcher")
    print("=" * 60)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit code.
    sys.exit(main())
|
||||
260
dispatcher/examples/gemm/python/10_advanced_benchmark.py
Normal file
260
dispatcher/examples/gemm/python/10_advanced_benchmark.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 10: Advanced Benchmarking with Full Control
|
||||
|
||||
This example demonstrates all available benchmark parameters:
|
||||
- warmup: Number of warmup iterations (default: 5)
|
||||
- repeat: Number of benchmark iterations (default: 20)
|
||||
- flush_cache: Flush GPU cache between iterations (default: False)
|
||||
- timer: Timer type - "gpu" (default) or "cpu"
|
||||
- init: Initialization method - "random", "linear", "constant"
|
||||
|
||||
Usage:
|
||||
python3 10_advanced_benchmark.py
|
||||
python3 10_advanced_benchmark.py --warmup 10 --repeat 100
|
||||
python3 10_advanced_benchmark.py --init linear
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the advanced benchmark example."""
    parser = argparse.ArgumentParser(
        description="Advanced GEMM benchmarking with full parameter control"
    )

    # Problem size (all dimensions default to 2048).
    for flag, text in (
        ("-m", "M dimension"),
        ("-n", "N dimension"),
        ("-k", "K dimension"),
    ):
        parser.add_argument(flag, type=int, default=2048, help=text)

    # Benchmark parameters
    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--repeat", type=int, default=20, help="Number of benchmark iterations"
    )
    parser.add_argument(
        "--flush-cache", action="store_true", help="Flush GPU cache between iterations"
    )
    parser.add_argument(
        "--timer", choices=["gpu", "cpu"], default="gpu", help="Timer type (gpu or cpu)"
    )
    parser.add_argument(
        "--init",
        choices=["random", "linear", "constant"],
        default="random",
        help="Initialization method",
    )

    # Kernel configuration
    parser.add_argument("--dtype", default="fp16", help="Data type")
    parser.add_argument("--pipeline", default="compv4", help="Pipeline type")
    parser.add_argument("--arch", default="gfx942", help="GPU architecture")

    return parser.parse_args()
|
||||
|
||||
|
||||
def initialize_matrix(shape, method, dtype):
    """Build a matrix of *shape*/*dtype* using the chosen init scheme.

    "random"   -> normal samples scaled by 0.5
    "linear"   -> sequential values normalized by the element count
    "constant" -> all ones
    Any other method falls back to unscaled normal samples.
    """
    if method == "linear":
        count = np.prod(shape)
        return np.arange(count).reshape(shape).astype(dtype) / count
    if method == "constant":
        return np.ones(shape, dtype=dtype)
    samples = np.random.randn(*shape).astype(dtype)
    return samples * 0.5 if method == "random" else samples
|
||||
|
||||
|
||||
def main():
    """Run one GEMM configuration with explicit warmup/repeat control and
    report average/min/max time, TFLOPS, and effective bandwidth.

    Returns 0 on success, 1 if setup fails or no iteration succeeds.
    """
    args = parse_args()

    reset_for_example()

    print("=" * 70)
    print("Example 10: Advanced GEMM Benchmarking")
    print("=" * 70)

    # Show benchmark configuration
    # NOTE(review): --flush-cache and --timer are printed and documented in
    # the reference text below, but this script never forwards them to
    # dispatcher.run — confirm whether they are meant to take effect here.
    print("\nBenchmark Configuration:")
    print(f" Problem Size: {args.m} x {args.n} x {args.k}")
    print(f" Warmup: {args.warmup} iterations")
    print(f" Repeat: {args.repeat} iterations")
    print(f" Flush Cache: {args.flush_cache}")
    print(f" Timer: {args.timer}")
    print(f" Init Method: {args.init}")
    print(f" Data Type: {args.dtype}")
    print(f" Pipeline: {args.pipeline}")
    print(f" Architecture: {args.arch}")
    print()

    # NumPy has no native bf16; bf16 data is staged host-side as float16.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # Initialize matrices
    print("Step 1: Initialize matrices...")
    A = initialize_matrix((args.m, args.k), args.init, np_dtype)
    B = initialize_matrix((args.k, args.n), args.init, np_dtype)
    print(f" A: {A.shape} ({args.init})")
    print(f" B: {B.shape} ({args.init})")

    # Create kernel config (does not include M/N/K - those are problem size)
    print("\nStep 2: Create kernel configuration...")
    kernel_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",  # B is column-major for optimal performance
        layout_c="row",
        tile_m=128,
        tile_n=128,
        tile_k=32,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=32,
        warp_n=32,
        warp_k=16,
        pipeline=args.pipeline,
        scheduler="intrawave",
        epilogue="cshuffle",
        gfx_arch=args.arch,
    )
    print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}")

    # Setup dispatcher
    print("\nStep 3: Setup dispatcher...")
    setup = setup_gemm_dispatcher(
        config=kernel_config,
        registry_name="benchmark_gemm",
        verbose=False,
        auto_rebuild=True,
    )

    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher
    print(f" Library: {setup.lib.path if setup.lib else 'N/A'}")
    print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}")

    # Run benchmark with multiple iterations
    print("\nStep 4: Run benchmark...")
    print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...")

    # Warmup iterations: results discarded.
    for _ in range(args.warmup):
        _ = dispatcher.run(A, B, args.m, args.n, args.k)

    # Benchmark: collect per-iteration times from successful runs only.
    times = []
    for _ in range(args.repeat):
        result = dispatcher.run(A, B, args.m, args.n, args.k)
        if result.success:
            times.append(result.time_ms)

    if times:
        avg_time = sum(times) / len(times)
        min_time = min(times)
        max_time = max(times)

        # Calculate TFLOPS: 2*M*N*K flops per GEMM; peak uses the fastest run.
        flops = 2 * args.m * args.n * args.k
        avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0
        max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0

        # Calculate bandwidth (C has same dtype as A and B):
        # one read of A and B plus one write of C per iteration.
        C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize
        bandwidth_gb = (
            (A.nbytes + B.nbytes + C_bytes) / 1e9 / (avg_time / 1000)
            if avg_time > 0
            else 0
        )

        print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***")
        print(f" Average Time: {avg_time:.4f} ms")
        print(f" Min Time: {min_time:.4f} ms")
        print(f" Max Time: {max_time:.4f} ms")
        print(f" Avg TFLOPS: {avg_tflops:.2f}")
        print(f" Peak TFLOPS: {max_tflops:.2f}")
        print(f" Bandwidth: {bandwidth_gb:.2f} GB/s")
    else:
        print(" FAILED: No successful runs")
        return 1

    # Summary
    print("\n" + "=" * 70)
    print("BENCHMARK PARAMETERS REFERENCE")
    print("=" * 70)
    print("""
Available parameters for GEMM benchmarking:

  --warmup N          Number of warmup iterations (discard results)
                      Higher = more stable results, longer run time
                      Default: 5

  --repeat N          Number of benchmark iterations
                      Higher = more accurate average, longer run time
                      Default: 20

  --flush-cache       Flush GPU L2 cache between iterations
                      Use for memory-bound benchmarks
                      Default: off

  --timer {gpu,cpu}   Timer type
                      gpu = HIP events (more accurate for GPU)
                      cpu = std::chrono (includes kernel launch overhead)
                      Default: gpu

  --init METHOD       Matrix initialization
                      random = uniform random [-0.5, 0.5]
                      linear = sequential values
                      constant = all ones
                      Default: random

Note: For C++ examples, these parameters are passed to stream_config:

  ck_tile::stream_config cfg{
      nullptr,  // stream_id
      true,     // time_kernel
      1,        // log_level
      5,        // cold_niters (warmup)
      20,       // nrepeat
      true,     // is_gpu_timer
      false,    // flush_cache
      1         // rotating_count
  };
""")

    # Cleanup
    cleanup_gemm()

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit code.
    sys.exit(main())
|
||||
310
dispatcher/examples/gemm/python/11_json_import.py
Normal file
310
dispatcher/examples/gemm/python/11_json_import.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 11: JSON-based Kernel Configuration Import
|
||||
|
||||
Demonstrates loading kernel configurations from JSON files, similar to tile_engine.
|
||||
This enables easy customization of kernel sets without modifying code.
|
||||
|
||||
Key Features:
|
||||
- Load tile configs from JSON (compatible with tile_engine format)
|
||||
- Generate kernel sets from configuration
|
||||
- Use arch_filter validation on loaded configs
|
||||
- Export to C++ DECL_KERNEL_SET format
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 11_json_import.py
|
||||
python3 11_json_import.py --config my_kernels.json
|
||||
python3 11_json_import.py --export-cpp
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Add codegen to path for kernel_config_loader
|
||||
script_dir = Path(__file__).parent.resolve()
|
||||
sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen"))
|
||||
sys.path.insert(0, str(script_dir.parent.parent.parent / "python"))
|
||||
|
||||
from kernel_config_loader import ( # noqa: E402
|
||||
load_kernel_configs,
|
||||
KernelConfig,
|
||||
generate_cpp_kernel_set_declaration,
|
||||
)
|
||||
|
||||
from ctypes_utils import ( # noqa: E402
|
||||
KernelConfig as DispatcherKernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
validate_kernel_config,
|
||||
)
|
||||
|
||||
# Sample JSON configuration (embedded for demonstration)
|
||||
# Embedded sample in the tile_engine-style JSON schema: each tile/trait
# field carries a "values" list, and the cross product of those lists
# defines the generated kernel set. main() writes this dict to a temp
# file when no --config path is given.
SAMPLE_JSON_CONFIG = {
    "_comment": "Sample kernel configuration for GEMM",
    "kernel_set_name": "inference_kernels",
    "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"},
    # "rcr" = row-major A, column-major B, row-major C.
    "layout": "rcr",
    "tile_config": {
        "tile_m": {"values": [128, 256]},
        "tile_n": {"values": [128, 256]},
        "tile_k": {"values": [32]},
        "warp_m": {"values": [2]},
        "warp_n": {"values": [2]},
        "warp_k": {"values": [1]},
        "warp_tile_m": {"values": [32]},
        "warp_tile_n": {"values": [32]},
        "warp_tile_k": {"values": [16]},
    },
    "trait_config": {
        "pipeline": {"values": ["compv4"]},
        "scheduler": {"values": ["intrawave"]},
        "epilogue": {"values": ["cshuffle"]},
        "pad_m": {"values": [False]},
        "pad_n": {"values": [False]},
        "pad_k": {"values": [False]},
    },
    "gpu_targets": ["gfx942"],
}
|
||||
|
||||
|
||||
def print_section(title: str):
    """Print *title* framed above and below by 70-character '=' rules."""
    bar = "=" * 70
    print("\n" + bar)
    print(f" {title}")
    print(bar + "\n")
|
||||
|
||||
|
||||
def convert_to_dispatcher_config(
    config: KernelConfig, arch: str = "gfx942"
) -> DispatcherKernelConfig:
    """Translate a loader-side KernelConfig into the dispatcher's KernelConfig.

    Field mapping note: the loader's warp_* fields become the dispatcher's
    wave_* fields, and its warp_tile_* fields become the dispatcher's warp_*.
    """
    tile = config.tile
    trait = config.trait
    return DispatcherKernelConfig(
        dtype_a=config.dtype_a,
        dtype_b=config.dtype_b,
        dtype_c=config.dtype_c,
        dtype_acc=config.dtype_acc,
        tile_m=tile.tile_m,
        tile_n=tile.tile_n,
        tile_k=tile.tile_k,
        wave_m=tile.warp_m,
        wave_n=tile.warp_n,
        wave_k=tile.warp_k,
        warp_m=tile.warp_tile_m,
        warp_n=tile.warp_tile_n,
        warp_k=tile.warp_tile_k,
        pipeline=trait.pipeline,
        scheduler=trait.scheduler,
        epilogue=trait.epilogue,
        pad_m=trait.pad_m,
        pad_n=trait.pad_n,
        pad_k=trait.pad_k,
        gfx_arch=arch,
        variant=config.variant,
    )
|
||||
|
||||
|
||||
def main():
    """Run the JSON kernel-configuration import example.

    Workflow:
      1. Load a kernel set from JSON (``--config`` file, or the embedded
         ``SAMPLE_JSON_CONFIG`` written to a temp file).
      2. Print the loaded tile/trait configuration details.
      3. Enumerate the generated kernel names (first five).
      4. Optionally (``--validate``) check each config against the target arch.
      5. Optionally (``--export-cpp``) emit a C++ DECL_KERNEL_SET declaration.
      6. Demonstrate dispatcher setup with the first configuration.

    Returns:
        int: 0 on success, 1 when the requested config file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="JSON Kernel Configuration Import Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 11_json_import.py # Use embedded sample config
python3 11_json_import.py --config my.json # Load from file
python3 11_json_import.py --export-cpp # Generate C++ declarations
python3 11_json_import.py --validate # Validate configs against arch
""",
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to JSON configuration file (uses embedded sample if not provided)",
    )
    parser.add_argument(
        "--export-cpp",
        action="store_true",
        help="Export kernel set as C++ DECL_KERNEL_SET",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate all configurations against arch filter",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target GPU architecture (default: gfx942)",
    )
    args = parser.parse_args()

    # NOTE(review): presumably clears global dispatcher/registry state so the
    # example starts from a clean slate — confirm against ctypes_utils.
    reset_for_example()

    print_section("Example 11: JSON Kernel Configuration Import")

    # =========================================================================
    # Step 1: Load configuration from JSON
    # =========================================================================
    print("Step 1: Load Kernel Configuration from JSON")
    print("-" * 50)

    if args.config:
        config_path = Path(args.config)
        if not config_path.exists():
            # Missing user-supplied file is a hard error: bail with exit code 1.
            print(f" ERROR: Config file not found: {config_path}")
            return 1
        print(f" Loading from: {config_path}")
        config_set = load_kernel_configs(config_path)
    else:
        # Use embedded sample config
        print(" Using embedded sample configuration")
        # Write to temp file and load: load_kernel_configs() takes a file path,
        # so the embedded dict is persisted to /tmp first.
        temp_path = Path("/tmp/sample_gemm_config.json")
        with open(temp_path, "w") as f:
            json.dump(SAMPLE_JSON_CONFIG, f, indent=2)
        config_set = load_kernel_configs(temp_path)

    print(f"\n Kernel Set Name: {config_set.name}")
    print(
        f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}"
    )
    print(f" Layout: {config_set.layout}")
    print(f" GPU Targets: {config_set.gpu_targets}")
    print(f" Total Configurations: {config_set.config_count()}")

    # =========================================================================
    # Step 2: Display configuration details
    # =========================================================================
    print("\nStep 2: Configuration Details")
    print("-" * 50)

    print("\n Tile Configurations:")
    print(f" tile_m: {config_set.tile_m_values}")
    print(f" tile_n: {config_set.tile_n_values}")
    print(f" tile_k: {config_set.tile_k_values}")
    print(
        f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}"
    )
    print(
        f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
    )

    print("\n Trait Configurations:")
    print(f" pipeline: {config_set.pipeline_values}")
    print(f" scheduler: {config_set.scheduler_values}")
    print(f" epilogue: {config_set.epilogue_values}")
    print(
        f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
    )

    # =========================================================================
    # Step 3: Generate and display kernel names
    # =========================================================================
    print("\nStep 3: Generated Kernel Names")
    print("-" * 50)

    # Materialize the generator once; the list is reused in Steps 4 and 6.
    configs = list(config_set.generate_configs())
    for i, config in enumerate(configs[:5]):
        print(f" {i + 1}. {config.kernel_name()}")
    if len(configs) > 5:
        print(f" ... and {len(configs) - 5} more configurations")

    # =========================================================================
    # Step 4: Validate against arch filter (optional)
    # =========================================================================
    if args.validate:
        print("\nStep 4: Architecture Validation")
        print("-" * 50)

        valid_count = 0
        invalid_count = 0

        for config in configs:
            disp_config = convert_to_dispatcher_config(config, args.arch)
            result = validate_kernel_config(disp_config)

            if result.is_valid:
                valid_count += 1
            else:
                invalid_count += 1
                if invalid_count <= 3:  # Show first 3 invalid
                    print(f"\n ✗ Invalid: {config.kernel_name()}")
                    for error in result.errors:
                        print(f" Error: {error}")

        print("\n Validation Summary:")
        print(f" ✓ Valid: {valid_count}")
        print(f" ✗ Invalid: {invalid_count}")
        print(f" Total: {len(configs)}")

    # =========================================================================
    # Step 5: Export to C++ (optional)
    # =========================================================================
    if args.export_cpp:
        print("\nStep 5: C++ Export")
        print("-" * 50)
        print("\n // Generated DECL_KERNEL_SET from JSON config:")
        print(" // " + "=" * 56)
        cpp_code = generate_cpp_kernel_set_declaration(config_set)
        # Indent each generated line so it nests under the banner above.
        for line in cpp_code.split("\n"):
            print(f" {line}")

    # =========================================================================
    # Step 6: Use first config with dispatcher (demo)
    # =========================================================================
    print("\nStep 6: Dispatcher Integration Demo")
    print("-" * 50)

    if configs:
        first_config = configs[0]
        disp_config = convert_to_dispatcher_config(first_config, args.arch)

        print(
            f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}"
        )

        setup = setup_gemm_dispatcher(
            disp_config, registry_name="json_import", verbose=False
        )
        if setup.success:
            print(" ✓ Dispatcher setup successful")
            print(
                f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}"
            )
        else:
            # Setup failure is reported but non-fatal: the demo is best-effort
            # when the kernels haven't been generated/built yet.
            print(f" ⚠ Dispatcher setup: {setup.error}")
            print(" (This is expected if kernels aren't generated)")

    # =========================================================================
    # Summary
    # =========================================================================
    print_section("Summary")
    print(" JSON configuration allows easy kernel set customization:")
    print(" - Define tile sizes and ranges")
    print(" - Specify trait combinations (pipeline, scheduler, etc.)")
    print(" - Target multiple GPU architectures")
    print(" - Export to C++ DECL_KERNEL_SET for static compilation")
    print()
    print(" JSON Format (tile_engine compatible):")
    print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},')
    print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}')
    print()
    print(" Usage:")
    print(" config_set = load_kernel_configs('my_kernels.json')")
    print(" for config in config_set.generate_configs():")
    print(" # Use config for codegen or dispatcher setup")

    # NOTE(review): presumably releases dispatcher/GPU resources acquired by
    # setup_gemm_dispatcher — confirm against ctypes_utils.
    cleanup_gemm()
    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
||||
299
dispatcher/examples/gemm/python/README.md
Normal file
299
dispatcher/examples/gemm/python/README.md
Normal file
@@ -0,0 +1,299 @@
|
||||
# GEMM Python Examples
|
||||
|
||||
CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.
|
||||
|
||||
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Build Library
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
mkdir -p build && cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
# Build Python library (kernels generated automatically)
|
||||
make dispatcher_gemm_lib -j$(nproc)
|
||||
```
|
||||
|
||||
### Run Examples
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
|
||||
python3 examples/gemm/python/01_basic_gemm.py
|
||||
python3 examples/gemm/python/04_validation.py
|
||||
python3 examples/gemm/python/07_stress_test.py
|
||||
python3 examples/gemm/python/08_heuristics.py
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
| Example | Description |
|
||||
|---------|-------------|
|
||||
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
|
||||
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
|
||||
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
|
||||
| [04_validation.py](04_validation.py) | CPU reference validation |
|
||||
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
|
||||
| [06_json_export.py](06_json_export.py) | Registry JSON export |
|
||||
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
|
||||
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
|
||||
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
|
||||
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
|
||||
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |
|
||||
|
||||
## Example Details
|
||||
|
||||
### 01_basic_gemm.py - Basic GEMM
|
||||
Demonstrates the Python API with multi-kernel support:
|
||||
|
||||
```python
|
||||
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
|
||||
|
||||
# Define multiple kernel configurations
|
||||
kernels = [
|
||||
KernelConfig(
|
||||
tile_m=128, tile_n=128, tile_k=32,
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
pipeline="compv3", scheduler="intrawave"
|
||||
),
|
||||
KernelConfig(
|
||||
tile_m=256, tile_n=256, tile_k=32,
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
pipeline="compv4", scheduler="intrawave"
|
||||
),
|
||||
]
|
||||
|
||||
# Display configurations
|
||||
print_kernel_config_table(kernels)
|
||||
|
||||
# Set up dispatcher with all kernels
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
|
||||
|
||||
# Run GEMM
|
||||
elapsed_ms = run_gemm(lib, M, N, K, ...)
|
||||
```
|
||||
|
||||
### 02_batch_gemm.py - Batch GEMM
|
||||
Batched matrix multiplication:
|
||||
- Multiple independent GEMM operations
|
||||
- Batch dimension handling
|
||||
|
||||
### 03_benchmark.py - Benchmarking
|
||||
Performance measurement:
|
||||
- GPU timing
|
||||
- TFLOPS calculation
|
||||
- Multiple iterations
|
||||
|
||||
### 04_validation.py - Validation
|
||||
Correctness verification:
|
||||
- NumPy reference implementation
|
||||
- Tolerance-based validation
|
||||
- Error reporting
|
||||
|
||||
### 05_numpy_integration.py - NumPy Integration
|
||||
Seamless NumPy integration:
|
||||
- NumPy arrays to GPU buffers
|
||||
- Results back to NumPy
|
||||
- Automatic type conversion
|
||||
|
||||
### 06_json_export.py - JSON Export
|
||||
Registry serialization for tool integration:
|
||||
- Export kernel configurations
|
||||
- Machine-readable format
|
||||
|
||||
### 07_stress_test.py - Stress Testing
|
||||
Comprehensive multi-kernel stress testing:
|
||||
|
||||
```python
|
||||
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
|
||||
|
||||
# Define 48 unique kernel configurations
|
||||
kernels = [
|
||||
KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
|
||||
KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
|
||||
KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
|
||||
# ... many more configurations
|
||||
]
|
||||
|
||||
# Test each kernel
|
||||
for i, kernel in enumerate(kernels):
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
|
||||
result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel
|
||||
print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- 48 unique kernel configurations
|
||||
- Various tile sizes, pipelines, and schedulers
|
||||
- Per-kernel validation with unique random seeds
|
||||
- Performance reporting
|
||||
|
||||
### 08_heuristics.py - Heuristic Selection
|
||||
Custom kernel selection based on problem characteristics:
|
||||
|
||||
```python
|
||||
# Define kernel pools for different strategies
|
||||
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
|
||||
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
|
||||
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
|
||||
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]
|
||||
|
||||
# Size-based heuristic
|
||||
def size_based_heuristic(M, N, K):
|
||||
if M * N < 512 * 512:
|
||||
return SMALL_KERNELS
|
||||
else:
|
||||
return LARGE_KERNELS
|
||||
|
||||
# Strategy-based selection
|
||||
def compute_strategy():
|
||||
return COMPUTE_KERNELS # Optimized for compute-bound problems
|
||||
|
||||
def memory_strategy():
|
||||
return MEMORY_KERNELS # Optimized for memory-bound problems
|
||||
|
||||
# Test different strategies
|
||||
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
|
||||
kernels = strategy(M, N, K)
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
|
||||
elapsed_ms = run_gemm(lib, M, N, K, ...)
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- 24 kernel configurations across 6 categories
|
||||
- Size-based heuristic (small vs large)
|
||||
- Optimization strategies (compute, memory, latency)
|
||||
- Performance comparison across strategies
|
||||
|
||||
### 09_multi_registry.py - Multiple Registries
|
||||
Separate registries for different workloads:
|
||||
- Compute-optimized registry
|
||||
- Latency-optimized registry
|
||||
- Dynamic registry selection
|
||||
|
||||
### 10_advanced_benchmark.py - Advanced Benchmark
|
||||
Full control over benchmark parameters:
|
||||
- Warmup iterations
|
||||
- Benchmark iterations
|
||||
- Statistical analysis
|
||||
|
||||
### 11_json_import.py - JSON Import
|
||||
Import kernel configurations from JSON:
|
||||
- External configuration files
|
||||
- Dynamic kernel loading
|
||||
|
||||
## Utility Module: ctypes_utils.py
|
||||
|
||||
```python
|
||||
from ctypes_utils import (
|
||||
KernelConfig, # Single kernel configuration
|
||||
setup_gemm_dispatcher, # Set up dispatcher with kernels
|
||||
print_kernel_config_table, # Display kernel configurations
|
||||
Dispatcher, # High-level dispatcher
|
||||
Registry, # Kernel registry
|
||||
Validator, # Validation utilities
|
||||
)
|
||||
```
|
||||
|
||||
### KernelConfig
|
||||
|
||||
```python
|
||||
config = KernelConfig(
|
||||
# Tile sizes
|
||||
tile_m=256, tile_n=256, tile_k=32,
|
||||
# Wave configuration
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
# Warp tile sizes
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
# Pipeline and scheduler
|
||||
pipeline="compv4", # "compv3" or "compv4"
|
||||
scheduler="intrawave", # "intrawave" or "interwave"
|
||||
# Optional
|
||||
epilogue="default",
|
||||
padding=True,
|
||||
double_buffer=True,
|
||||
)
|
||||
```
|
||||
|
||||
### setup_gemm_dispatcher
|
||||
|
||||
```python
|
||||
# Single kernel
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(config)
|
||||
|
||||
# Multiple kernels
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])
|
||||
|
||||
# With auto-rebuild
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
|
||||
```
|
||||
|
||||
### print_kernel_config_table
|
||||
|
||||
```python
|
||||
kernels = [config1, config2, config3]
|
||||
print_kernel_config_table(kernels)
|
||||
# Output:
|
||||
# +----+------------+-------+----------+--------+-----------+
# | #  | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +----+------------+-------+----------+--------+-----------+
# | 1  | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2  | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3  | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +----+------------+-------+----------+--------+-----------+
|
||||
```
|
||||
|
||||
### GPU Memory Management
|
||||
|
||||
```python
|
||||
import ctypes
|
||||
import numpy as np
|
||||
|
||||
# Load HIP library
|
||||
hip = ctypes.CDLL("libamdhip64.so")
|
||||
|
||||
# Allocate GPU memory
|
||||
gpu_ptr = ctypes.c_void_p()
|
||||
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)
|
||||
|
||||
# Copy to GPU (1 = hipMemcpyHostToDevice)
|
||||
hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1)
|
||||
|
||||
# Copy back (2 = hipMemcpyDeviceToHost)
|
||||
hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2)
|
||||
|
||||
# Free
|
||||
hip.hipFree(gpu_ptr)
|
||||
```
|
||||
|
||||
## Performance Testing
|
||||
|
||||
Test compilation performance with different kernel counts:
|
||||
|
||||
```bash
|
||||
# Test with 10 kernels (~15s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 10
|
||||
|
||||
# Test with 20 kernels (~25s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 20
|
||||
|
||||
# Test with 48 kernels (~50s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 48
|
||||
```
|
||||
|
||||
Compilation time scales roughly linearly with kernel count.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [C++ GEMM Examples](../cpp/README.md)
|
||||
- [Python Conv Examples](../../conv/python/README.md)
|
||||
- [Main Dispatcher README](../../../README.md)
|
||||
80
dispatcher/examples/gemm/python/kernels.json
Normal file
80
dispatcher/examples/gemm/python/kernels.json
Normal file
@@ -0,0 +1,80 @@
|
||||
{
|
||||
"registry": "export_demo",
|
||||
"kernel_count": 3,
|
||||
"kernels": [
|
||||
{
|
||||
"tile": "128x128x32",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
},
|
||||
{
|
||||
"tile": "256x256x64",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
},
|
||||
{
|
||||
"tile": "64x64x32",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
}
|
||||
],
|
||||
"cpp_registry": {
|
||||
"metadata": {
|
||||
"timestamp": "Dec 4 2025 06:23:15",
|
||||
"total_kernels": 1,
|
||||
"export_version": "1.0",
|
||||
"dispatcher_version": "1.0.0"
|
||||
},
|
||||
"statistics": {
|
||||
"by_datatype": {},
|
||||
"by_pipeline": {},
|
||||
"by_scheduler": {}
|
||||
},
|
||||
"kernels": [
|
||||
{
|
||||
"identifier": "128x128x32_2x2x1_32x32x16_nopers",
|
||||
"name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16",
|
||||
"algorithm": {
|
||||
"tile_shape": {
|
||||
"m": 128,
|
||||
"n": 128,
|
||||
"k": 32
|
||||
},
|
||||
"wave_shape": {
|
||||
"m": 2,
|
||||
"n": 2,
|
||||
"k": 1
|
||||
},
|
||||
"warp_tile_shape": {
|
||||
"m": 32,
|
||||
"n": 32,
|
||||
"k": 16
|
||||
},
|
||||
"block_size": 256,
|
||||
"persistent": false,
|
||||
"double_buffer": true,
|
||||
"preshuffle": false,
|
||||
"transpose_c": false
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user