Files
composable_kernel/dispatcher/examples/gemm/python/09_multi_registry.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.
2026-01-22 09:34:33 -08:00

221 lines
6.4 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 09: Multiple Registries
Demonstrates multiple registries for different optimization targets.
Complexity: ★★★★★
Usage:
python3 09_multi_registry.py
python3 09_multi_registry.py --help
python3 09_multi_registry.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
Registry,
Dispatcher,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
parser = argparse.ArgumentParser(
description="Multiple Registries Example - optimization-specific registries",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python3 09_multi_registry.py # Default FP16
python3 09_multi_registry.py --dtype bf16 # BF16 mode
""",
)
parser.add_argument(
"--dtype",
default="fp16",
choices=["fp16", "bf16", "fp32"],
help="Data type (default: fp16)",
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
)
args = parser.parse_args()
reset_for_example()
print("=" * 60)
print("Example 09: Multiple Registries")
print("=" * 60)
# =========================================================================
# Step 1: Setup base dispatcher
# =========================================================================
print("\nStep 1: Setup Base Dispatcher")
base_config = KernelConfig(
dtype_a=args.dtype,
dtype_b=args.dtype,
dtype_c=args.dtype,
tile_m=128,
tile_n=128,
tile_k=32,
gfx_arch=args.arch,
)
setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True)
if not setup.success:
print(f" ERROR: {setup.error}")
return 1
lib = setup.lib
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
# =========================================================================
# Step 2: Define configs for different optimization targets
# =========================================================================
print("\nStep 2: Define Optimization Targets")
compute_config = KernelConfig(
dtype_a=args.dtype,
dtype_b=args.dtype,
dtype_c=args.dtype,
tile_m=256,
tile_n=256,
tile_k=64,
wave_m=4,
wave_n=4,
pipeline="compv4",
gfx_arch=args.arch,
)
memory_config = KernelConfig(
dtype_a=args.dtype,
dtype_b=args.dtype,
dtype_c=args.dtype,
tile_m=128,
tile_n=128,
tile_k=32,
wave_m=2,
wave_n=2,
pipeline="compv4",
gfx_arch=args.arch,
)
latency_config = KernelConfig(
dtype_a=args.dtype,
dtype_b=args.dtype,
dtype_c=args.dtype,
tile_m=64,
tile_n=64,
tile_k=32,
wave_m=1,
wave_n=1,
pipeline="compv3",
gfx_arch=args.arch,
)
print(f" Compute: {compute_config.tile_str} (large matrices)")
print(f" Memory: {memory_config.tile_str} (medium matrices)")
print(f" Latency: {latency_config.tile_str} (small matrices)")
# =========================================================================
# Step 3: Create registries
# =========================================================================
print("\nStep 3: Create Registries")
compute_registry = Registry(name="compute", lib=lib)
compute_registry.register_kernel(compute_config)
memory_registry = Registry(name="memory", lib=lib)
memory_registry.register_kernel(memory_config)
latency_registry = Registry(name="latency", lib=lib)
latency_registry.register_kernel(latency_config)
# =========================================================================
# Step 4: Create dispatchers
# =========================================================================
print("\nStep 4: Create Dispatchers")
compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib)
memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib)
latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib)
print(f" {compute_dispatcher}")
print(f" {memory_dispatcher}")
print(f" {latency_dispatcher}")
# =========================================================================
# Step 5: Smart dispatcher selection
# =========================================================================
print("\nStep 5: Smart Dispatcher Selection")
def select_dispatcher(M: int, N: int, K: int) -> Dispatcher:
elements = M * N
if elements >= 4096 * 4096:
return compute_dispatcher
elif elements >= 1024 * 1024:
return memory_dispatcher
else:
return latency_dispatcher
test_sizes = [
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(2048, 2048, 2048),
(4096, 4096, 4096),
]
print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}")
print(" " + "-" * 55)
for M, N, K in test_sizes:
dispatcher = select_dispatcher(M, N, K)
if not dispatcher.is_supported(M, N, K):
continue
A = np.random.randn(M, K).astype(np_dtype) * 0.1
B = np.random.randn(K, N).astype(np_dtype) * 0.1
result = dispatcher.run(A, B, M, N, K)
if result.success:
print(
f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} "
f"{result.time_ms:>12.4f} {result.tflops:>10.2f}"
)
# Cleanup
cleanup_gemm()
# Summary
print("\n" + "=" * 60)
print("Multi-Registry Pattern:")
print("=" * 60)
print(" 1. Define KernelConfig for each optimization target")
print(" 2. Create Registry for each target")
print(" 3. Register configs to appropriate registries")
print(" 4. Create Dispatcher for each registry")
print(" 5. Select dispatcher based on problem characteristics")
print(" 6. Run GEMM with selected dispatcher")
print("=" * 60)
return 0
if __name__ == "__main__":
sys.exit(main())