mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 16:47:40 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving code generation * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removal.
221 lines
6.4 KiB
Python
221 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Example 09: Multiple Registries
|
|
|
|
Demonstrates multiple registries for different optimization targets.
|
|
|
|
Complexity: ★★★★★
|
|
|
|
Usage:
|
|
python3 09_multi_registry.py
|
|
python3 09_multi_registry.py --help
|
|
python3 09_multi_registry.py --dtype bf16
|
|
"""
|
|
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
|
import numpy as np
|
|
|
|
from ctypes_utils import (
|
|
KernelConfig,
|
|
Registry,
|
|
Dispatcher,
|
|
setup_gemm_dispatcher,
|
|
cleanup_gemm,
|
|
reset_for_example,
|
|
)
|
|
|
|
|
|
def main():
    """Run the multi-registry GEMM example.

    Parses ``--dtype`` and ``--arch``, sets up a base dispatcher, builds three
    registries ("compute", "memory", "latency") that each hold one kernel
    configuration, and then runs one GEMM per entry of a fixed list of problem
    sizes on the dispatcher chosen from the output size ``M * N``.

    Returns:
        int: 0 when the example ran to completion, 1 when the base
        dispatcher setup reported failure.
    """
    parser = argparse.ArgumentParser(
        description="Multiple Registries Example - optimization-specific registries",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 09_multi_registry.py # Default FP16
python3 09_multi_registry.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 09: Multiple Registries")
    print("=" * 60)

    # Keyword arguments shared by every KernelConfig built below.
    common = {
        "dtype_a": args.dtype,
        "dtype_b": args.dtype,
        "dtype_c": args.dtype,
        "gfx_arch": args.arch,
    }

    # =========================================================================
    # Step 1: Setup base dispatcher
    # =========================================================================
    print("\nStep 1: Setup Base Dispatcher")

    base_config = KernelConfig(tile_m=128, tile_n=128, tile_k=32, **common)

    setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    # From here on the setup succeeded, so cleanup_gemm() must run even if a
    # later step raises; the try/finally guarantees that.
    try:
        lib = setup.lib
        # NOTE(review): bf16 inputs are generated as float16 host arrays, since
        # NumPy has no native bfloat16 -- confirm this is what dispatcher.run
        # expects for bf16 kernels.
        np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

        # =====================================================================
        # Step 2: Define configs for different optimization targets
        # =====================================================================
        print("\nStep 2: Define Optimization Targets")

        compute_config = KernelConfig(
            tile_m=256,
            tile_n=256,
            tile_k=64,
            wave_m=4,
            wave_n=4,
            pipeline="compv4",
            **common,
        )
        memory_config = KernelConfig(
            tile_m=128,
            tile_n=128,
            tile_k=32,
            wave_m=2,
            wave_n=2,
            pipeline="compv4",
            **common,
        )
        latency_config = KernelConfig(
            tile_m=64,
            tile_n=64,
            tile_k=32,
            wave_m=1,
            wave_n=1,
            pipeline="compv3",
            **common,
        )

        print(f" Compute: {compute_config.tile_str} (large matrices)")
        print(f" Memory: {memory_config.tile_str} (medium matrices)")
        print(f" Latency: {latency_config.tile_str} (small matrices)")

        # =====================================================================
        # Step 3: Create registries
        # =====================================================================
        print("\nStep 3: Create Registries")

        compute_registry = Registry(name="compute", lib=lib)
        compute_registry.register_kernel(compute_config)

        memory_registry = Registry(name="memory", lib=lib)
        memory_registry.register_kernel(memory_config)

        latency_registry = Registry(name="latency", lib=lib)
        latency_registry.register_kernel(latency_config)

        # =====================================================================
        # Step 4: Create dispatchers
        # =====================================================================
        print("\nStep 4: Create Dispatchers")

        compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib)
        memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib)
        latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib)

        print(f" {compute_dispatcher}")
        print(f" {memory_dispatcher}")
        print(f" {latency_dispatcher}")

        # =====================================================================
        # Step 5: Smart dispatcher selection
        # =====================================================================
        print("\nStep 5: Smart Dispatcher Selection")

        def select_dispatcher(M: int, N: int, K: int) -> Dispatcher:
            """Pick a dispatcher from the output size ``M * N``.

            ``K`` is accepted so the call mirrors the GEMM problem shape, but
            it does not take part in the selection.
            """
            elements = M * N
            if elements >= 4096 * 4096:
                return compute_dispatcher
            if elements >= 1024 * 1024:
                return memory_dispatcher
            return latency_dispatcher

        test_sizes = [
            (256, 256, 256),
            (512, 512, 512),
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
        ]

        print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}")
        print(" " + "-" * 55)

        for M, N, K in test_sizes:
            dispatcher = select_dispatcher(M, N, K)
            # Pad the complete "MxNxK" label to the header's 20 columns.
            # Padding only K makes the row width depend on the digits of M and
            # N, which shifts every following column.
            size_label = f"{M}x{N}x{K}"
            registry_name = dispatcher.registry.name

            if not dispatcher.is_supported(M, N, K):
                # Report the skip instead of silently dropping the row.
                print(
                    f" {size_label:<20} {registry_name:>10} "
                    f"{'unsupported':>12} {'-':>10}"
                )
                continue

            A = np.random.randn(M, K).astype(np_dtype) * 0.1
            B = np.random.randn(K, N).astype(np_dtype) * 0.1

            result = dispatcher.run(A, B, M, N, K)

            if result.success:
                print(
                    f" {size_label:<20} {registry_name:>10} "
                    f"{result.time_ms:>12.4f} {result.tflops:>10.2f}"
                )
            else:
                # The result type is not visible in this file, so any error
                # text is read defensively rather than assumed to exist.
                detail = getattr(result, "error", None)
                suffix = f" ({detail})" if detail else ""
                print(
                    f" {size_label:<20} {registry_name:>10} "
                    f"{'FAILED':>12} {'-':>10}{suffix}"
                )
    finally:
        # Cleanup
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("Multi-Registry Pattern:")
    print("=" * 60)
    print(" 1. Define KernelConfig for each optimization target")
    print(" 2. Create Registry for each target")
    print(" 3. Register configs to appropriate registries")
    print(" 4. Create Dispatcher for each registry")
    print(" 5. Select dispatcher based on problem characteristics")
    print(" 6. Run GEMM with selected dispatcher")
    print("=" * 60)

    return 0
|
|
|
|
|
|
# Script entry point: run the example and hand main()'s return value to the
# interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|