Files
composable_kernel/dispatcher/examples/gemm/python/08_heuristics.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

719 lines
20 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 08: Custom Heuristics
Demonstrates custom kernel selection heuristics based on problem characteristics.
This example shows how to:
1. Define multiple kernel configurations for different workloads
2. Implement custom heuristics to select the best kernel
3. Test heuristic selection across different problem sizes
Heuristic strategies:
- Size-based: Small tiles for small problems, large tiles for large problems
- Compute-bound: Maximize compute utilization for large matrices
- Memory-bound: Optimize memory access for bandwidth-limited cases
- Latency-focused: Minimize kernel launch overhead for small problems
Complexity: ★★★★☆
Usage:
python3 08_heuristics.py
python3 08_heuristics.py --help
python3 08_heuristics.py --strategy compute
python3 08_heuristics.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
from enum import Enum
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
# =============================================================================
# Kernel Specifications
# =============================================================================
@dataclass
class KernelSpec:
    """Kernel specification with metadata for heuristic selection.

    Combines the raw tile/pipeline parameters (used to build a KernelConfig)
    with heuristic metadata — a category label and a valid problem-size
    window measured in M*N output elements — that the selection strategies
    filter and rank on.
    """

    name: str    # human-readable identifier; also used in the registry name
    tile_m: int  # block tile extent along M
    tile_n: int  # block tile extent along N
    tile_k: int  # block tile extent along K (reduction dimension)
    pipeline: str = "compv3"      # GEMM pipeline variant ("compv3" or "compv4")
    scheduler: str = "intrawave"  # instruction scheduler variant
    # Metadata for heuristics
    category: str = "balanced"  # small, balanced, large, compute, memory
    min_problem_size: int = 0   # inclusive lower bound on M*N for this kernel
    # Inclusive upper bound on M*N; annotated float because the default is
    # float("inf") (unbounded).
    max_problem_size: float = float("inf")
# Define kernel pool for heuristic selection (20+ kernels)
# Each entry pairs tile/pipeline parameters with the heuristic metadata
# (category, valid M*N window) that the selection strategies filter on.
KERNEL_POOL = [
    # ==========================================================================
    # SMALL TILES - Low latency, good for small problems
    # ==========================================================================
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3", "intrawave",
               category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3", "intrawave",
               category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_v4", 64, 64, 32, "compv4", "intrawave",
               category="small", max_problem_size=256 * 256),
    # ==========================================================================
    # MEDIUM TILES - Balanced performance
    # ==========================================================================
    KernelSpec("medium_128x128_k32", 128, 128, 32, "compv3", "intrawave",
               category="balanced", min_problem_size=128 * 128,
               max_problem_size=2048 * 2048),
    KernelSpec("medium_128x128_k64", 128, 128, 64, "compv3", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_k128", 128, 128, 128, "compv3", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    # Rectangular medium tiles
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3", "intrawave",
               category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3", "intrawave",
               category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3", "intrawave",
               category="balanced", min_problem_size=256 * 256),
    # ==========================================================================
    # LARGE TILES - High throughput for large problems
    # ==========================================================================
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3", "intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3", "intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3", "intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3", "intrawave",
               category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3", "intrawave",
               category="large", min_problem_size=1024 * 1024),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3", "intrawave",
               category="large", min_problem_size=1024 * 1024),
    # ==========================================================================
    # COMPUTE-OPTIMIZED - compv4 pipeline for compute-bound workloads
    # ==========================================================================
    KernelSpec("compute_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave",
               category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave",
               category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_256x128_v4", 256, 128, 64, "compv4", "intrawave",
               category="compute", min_problem_size=512 * 512),
    KernelSpec("compute_256x256_v4", 256, 256, 64, "compv4", "intrawave",
               category="compute", min_problem_size=1024 * 1024),
    # ==========================================================================
    # MEMORY-OPTIMIZED - Good cache utilization for memory-bound workloads
    # ==========================================================================
    KernelSpec("memory_128x128_k16", 128, 128, 16, "compv3", "intrawave",
               category="memory", min_problem_size=256 * 256),
    KernelSpec("memory_64x128_k16", 64, 128, 16, "compv3", "intrawave",
               category="memory", min_problem_size=128 * 128),
]
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Translate a KernelSpec into a concrete KernelConfig.

    Warp extents are derived from the tile shape: a tile extent of 64 or
    less uses a 16-wide warp in that dimension, larger tiles use 32.
    Layouts are fixed to row/col/row (A row-major, B col-major, C row-major)
    with fp32 accumulation and a cshuffle epilogue.
    """

    def _warp_for(tile_extent: int) -> int:
        # 16-wide warp for small tiles, 32-wide otherwise.
        return 16 if tile_extent <= 64 else 32

    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=_warp_for(spec.tile_m),
        warp_n=_warp_for(spec.tile_n),
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
# =============================================================================
# Heuristic Strategies
# =============================================================================
class HeuristicStrategy(Enum):
    """Available selection strategies; values match the --strategy CLI choices."""

    SIZE_BASED = "size"          # choose tiles from the M*N magnitude
    COMPUTE_BOUND = "compute"    # favor compv4 pipeline / large tiles
    MEMORY_BOUND = "memory"      # favor memory-friendly (small tile_k) kernels
    LATENCY_FOCUSED = "latency"  # favor small tiles for low launch latency
def size_based_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel based on problem size.
    - Small problems: Use small tiles for low latency
    - Medium problems: Use balanced tiles
    - Large problems: Use large tiles for high throughput
    Also considers the K dimension: among otherwise-equal candidates,
    prefers the tile_k that wastes the least K-padding.

    Args:
        M, N, K: GEMM problem dimensions.
        kernels: candidate kernel pool.

    Returns:
        The selected KernelSpec. Always returns one: if no kernel matches
        the size constraints, the full pool is used as a fallback.
    """
    total_elements = M * N
    # Filter by each kernel's declared valid problem-size window.
    candidates = [
        k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size
    ]
    if not candidates:
        candidates = kernels  # Fall back to all kernels
    # Determine target category based on problem size
    if total_elements < 256 * 256:
        target_category = "small"
    elif total_elements < 1024 * 1024:
        target_category = "balanced"
    else:
        target_category = "large"
    # Filter by category if possible
    category_candidates = [k for k in candidates if k.category == target_category]
    if category_candidates:
        candidates = category_candidates

    # Select best tile_k based on K dimension: prefer tile_k that divides K
    # with the least padding.
    def tile_k_score(k):
        # Padding elements needed to round K up to a multiple of tile_k;
        # 0 means tile_k divides K exactly. NOTE: the raw remainder
        # K % tile_k is NOT a fit measure — e.g. K=129 gives remainder 1
        # for both tile_k=32 and tile_k=128, yet the latter pads 127
        # elements while the former pads only 31.
        return (-K) % k.tile_k

    # Sort by K fit (least padding first), then prefer larger MN tiles.
    candidates.sort(key=lambda k: (tile_k_score(k), -k.tile_m * k.tile_n))
    return candidates[0]
def compute_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel for compute-bound workloads.

    The candidate pool is chosen with decreasing preference: kernels in
    the "compute" category, then any compv4-pipeline kernel, then the
    whole pool. Large problems (>= 1024*1024 output elements) get the
    largest tile volume; smaller ones get the tile closest to 128x128.
    """
    total_elements = M * N

    # Build the candidate pool via preference fallbacks.
    pool = [k for k in kernels if k.category == "compute"]
    if not pool:
        pool = [k for k in kernels if k.pipeline == "compv4"]
    if not pool:
        pool = kernels

    # Keep only kernels whose minimum size fits, unless that empties the pool.
    sized = [k for k in pool if k.min_problem_size <= total_elements]
    if sized:
        pool = sized

    if total_elements >= 1024 * 1024:
        # Large problem: maximize tile volume for compute utilization.
        return max(pool, key=lambda k: k.tile_m * k.tile_n * k.tile_k)
    # Smaller problem: tile shape closest to 128x128.
    return min(pool, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128))
def memory_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel for memory-bound (bandwidth-limited) workloads.

    "memory"-category kernels win outright: smallest MN tile for problems
    under 512x512, largest otherwise. Failing that, the "balanced" kernel
    with the smallest tile_k is chosen; as a last resort, the pool-wide
    kernel with the smallest tile_k, tie-broken by proximity to 128x128.
    """
    memory_pool = [k for k in kernels if k.category == "memory"]
    if memory_pool:
        # Scale the memory-optimized tile with the output size.
        if M * N < 512 * 512:
            return min(memory_pool, key=lambda k: k.tile_m * k.tile_n)
        return max(memory_pool, key=lambda k: k.tile_m * k.tile_n)

    balanced_pool = [k for k in kernels if k.category == "balanced"]
    if balanced_pool:
        # Smaller tile_k favors memory access patterns.
        return min(balanced_pool, key=lambda k: k.tile_k)

    # Last resort: smallest tile_k, then closest to a 128x128 tile.
    return min(
        kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128))
    )
def latency_focused_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Pick a kernel for minimum launch-to-result latency.

    "small"-category kernels are preferred, with compv4 variants chosen
    first. If no small kernel exists, the smallest-MN-tile compv4 kernel
    is used, and finally the smallest tile overall.
    """
    small_pool = [k for k in kernels if k.category == "small"]
    if small_pool:
        # First compv4 small kernel if any, else the first small kernel.
        for spec in small_pool:
            if spec.pipeline == "compv4":
                return spec
        return small_pool[0]

    # No small kernels: fall back to compv4 kernels if available, else all.
    v4_pool = [k for k in kernels if k.pipeline == "compv4"]
    pool = v4_pool if v4_pool else kernels
    return min(pool, key=lambda k: k.tile_m * k.tile_n)
# Dispatch table: strategy enum -> heuristic implementation. All entries
# share the signature (M, N, K, kernels) -> KernelSpec.
HEURISTICS = {
    HeuristicStrategy.SIZE_BASED: size_based_heuristic,
    HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic,
    HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic,
    HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic,
}
# =============================================================================
# Main
# =============================================================================
def print_kernel_pool(kernels: List[KernelSpec]):
    """Print a formatted table of the available kernel specs."""
    banner = "=" * 75
    rule = " " + "-" * 73
    print("\n" + banner)
    print(" KERNEL POOL")
    print(banner)
    # Column header, then one row per kernel.
    print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}")
    print(rule)
    for idx, spec in enumerate(kernels, 1):
        shape = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(f" {idx:<3} {spec.name:<22} {shape:<14} {spec.pipeline:<10} {spec.category:<12}")
    print(rule)
def main() -> int:
    """Entry point: parse CLI options, run the chosen heuristic across a
    sweep of problem sizes, execute each selected kernel through the
    dispatcher, validate against a CPU reference, and print a summary.

    Returns:
        0 if every executed test passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Custom Heuristics Example - intelligent kernel selection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 08_heuristics.py # Default size-based heuristic
python3 08_heuristics.py --strategy compute # Compute-bound heuristic
python3 08_heuristics.py --strategy memory # Memory-bound heuristic
python3 08_heuristics.py --strategy latency # Latency-focused heuristic
python3 08_heuristics.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--strategy",
        default="size",
        choices=["size", "compute", "memory", "latency"],
        help="Heuristic strategy (default: size)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()
    # Clear any state left behind by previously-run examples.
    reset_for_example()
    print("=" * 75)
    print("Example 08: Custom Heuristics")
    print("=" * 75)
    # Map strategy string to enum
    strategy_map = {
        "size": HeuristicStrategy.SIZE_BASED,
        "compute": HeuristicStrategy.COMPUTE_BOUND,
        "memory": HeuristicStrategy.MEMORY_BOUND,
        "latency": HeuristicStrategy.LATENCY_FOCUSED,
    }
    strategy = strategy_map[args.strategy]
    heuristic_fn = HEURISTICS[strategy]
    print(f"\n Strategy: {strategy.value}")
    print(f" Data type: {args.dtype}")
    # Print kernel pool
    print_kernel_pool(KERNEL_POOL)
    # =========================================================================
    # Test heuristic selection across different problem sizes
    # =========================================================================
    print("\n" + "=" * 75)
    print(" HEURISTIC SELECTION TEST")
    print("=" * 75)
    # NOTE(review): bf16 inputs are generated as np.float16 since numpy has no
    # native bfloat16 — presumably the dispatcher converts internally; confirm.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    test_sizes = [
        (128, 128, 64),  # Small
        (256, 256, 128),  # Small-medium
        (512, 512, 256),  # Medium
        (1024, 1024, 512),  # Medium-large
        (2048, 2048, 1024),  # Large
    ]
    print(
        f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)
    # Each results entry: (size string, kernel name, passed?, time_ms, tflops).
    results = []
    for M, N, K in test_sizes:
        # Use heuristic to select kernel
        selected_spec = heuristic_fn(M, N, K, KERNEL_POOL)
        # Create config and setup
        config = create_kernel_config(selected_spec, args.dtype, args.arch)
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"heuristic_{selected_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        size_str = f"{M}x{N}x{K}"
        # Dispatcher setup failed: record a failure and move on.
        if not setup.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        dispatcher = setup.dispatcher
        # Kernel built but cannot serve this shape: record as a skip.
        if not dispatcher.is_supported(M, N, K):
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        # Run GEMM
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        # Validate against a float32 CPU reference (max absolute error).
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))
        passed = max_err < 1e-2
        status = "PASS" if passed else "FAIL"
        print(
            f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
        )
        results.append(
            (size_str, selected_spec.name, passed, result.time_ms, result.tflops)
        )
        cleanup_gemm()
    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    failed = len(results) - passed
    print(f"\n Strategy: {strategy.value}")
    print(f" Results: {passed}/{len(results)} tests passed")
    # Show kernel selection distribution
    kernel_usage = {}
    for r in results:
        kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1
    print("\n Kernel Selection Distribution:")
    for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]):
        print(f" {kernel}: {count} times")
    # Average TFLOPS over passing runs only.
    if results:
        valid_results = [r for r in results if r[2]]
        if valid_results:
            avg_tflops = sum(r[4] for r in valid_results) / len(valid_results)
            print(f"\n Average TFLOPS: {avg_tflops:.2f}")
    if failed == 0:
        print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n *** {failed} TESTS FAILED ***")
    print("=" * 75)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    # Propagate the pass/fail result as the process exit code.
    sys.exit(main())