Files
composable_kernel/dispatcher/examples/gemm/python/01_basic_gemm.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

332 lines
11 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 01: Basic GEMM with Multiple Kernels
Demonstrates:
1. Declaring multiple kernel configurations
2. Printing all registered kernels
3. Running each kernel and validating output
4. Comparing performance across kernels
Complexity: ★★☆☆☆
Usage:
python3 01_basic_gemm.py
python3 01_basic_gemm.py --help
python3 01_basic_gemm.py --dtype bf16
python3 01_basic_gemm.py --size 2048
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
@dataclass()
class KernelSpec:
    """One GEMM kernel variant to build, run and benchmark.

    Only the block-tile shape, the pipeline and the scheduler vary between
    variants; every other kernel parameter is fixed by create_kernel_config().
    """

    # Human-readable label; main() also derives the registry name from it.
    name: str
    # Block tile extents along the M, N and K dimensions.
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline / scheduler identifiers, forwarded unchanged to KernelConfig.
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
# Kernel configurations exercised by this example (currently 48 entries):
# block tiles from 32x64 up to 256x256, compv3/compv4 pipelines, and a few
# interwave-scheduler variants. main() runs all of them unless --num-kernels
# limits the run to the first N entries.
KERNEL_SPECS = [
    # Small tiles - compv3
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
    # Small tiles - compv4
    KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
    KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
    # Medium tiles - compv3
    KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
    KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
    KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
    # Medium tiles - compv4
    KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
    KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
    KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
    # Rectangular tiles - compv3
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
    # Rectangular tiles - compv4
    KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
    KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
    KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
    KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
    # Large tiles - compv3
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
    # Large tiles - compv4
    KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
    KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
    KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
    KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
    KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
    KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
    # Interwave scheduler variants
    KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
    # More tile_k variations - compv3
    KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
    KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
    KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
    # More tile_k variations - compv4
    KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
    KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
    KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
    # Additional rectangular
    KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
    KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
    KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
    KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
    # Additional compv4 variants
    KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
    KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
    KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
    KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Create a KernelConfig from a spec.

    The A/B/C data types are all ``dtype`` with fp32 accumulation, the
    layouts are row/col/row, the wave grid is 2x2x1 and the epilogue is
    cshuffle. Only the block tile, pipeline and scheduler come from ``spec``;
    the warp tile is derived from the block tile (see below).

    Args:
        spec: Kernel variant (tile shape, pipeline, scheduler).
        dtype: Element type name used for A, B and C (e.g. "fp16").
        arch: Target GPU architecture string (e.g. "gfx942").

    Returns:
        The KernelConfig to hand to setup_gemm_dispatcher().
    """
    wave_m, wave_n, wave_k = 2, 2, 1
    large_warp = 32
    # Larger tiles use 32x32 warp tiles and smaller ones 16x16. The 32x32 choice
    # additionally requires each tile dimension to be a multiple of
    # wave * 32; otherwise (e.g. a 128x32 tile, where 2 waves x 32 columns
    # = 64 > 32) fall back to 16x16 so the wave grid still tiles the block.
    fits_large_warp = (
        spec.tile_m % (wave_m * large_warp) == 0
        and spec.tile_n % (wave_n * large_warp) == 0
    )
    if spec.tile_m > 64 and fits_large_warp:
        warp_m, warp_n = large_warp, large_warp
    else:
        warp_m, warp_n = 16, 16
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=wave_m,
        wave_n=wave_n,
        wave_k=wave_k,
        warp_m=warp_m,
        warp_n=warp_n,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
def print_kernel_table(specs: List[KernelSpec], dtype: str):
    """Print a formatted table of the declared kernel configurations.

    Args:
        specs: Kernel specs to list, numbered from 1 in the given order.
        dtype: Data-type label printed beneath the table.
    """
    # Several spec names exceed 18 characters (e.g. "large_256x128_v4_k32"),
    # so widen the Name column to the longest name; 18 stays the minimum so
    # tables of short names render exactly as before.
    name_w = max([18] + [len(spec.name) for spec in specs])
    print("\n" + "=" * 70)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print("=" * 70)
    print(
        f"\n {'#':<3} {'Name':<{name_w}} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}"
    )
    print(" " + "-" * 68)
    for i, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(
            f" {i:<3} {spec.name:<{name_w}} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}"
        )
    print(" " + "-" * 68)
    print(f" Data type: {dtype}")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this example."""
    parser = argparse.ArgumentParser(
        description="Basic GEMM Example with Multiple Kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 01_basic_gemm.py                    # Default FP16, all declared kernels
  python3 01_basic_gemm.py --dtype bf16       # BF16 mode
  python3 01_basic_gemm.py --size 2048        # Larger problem size
  python3 01_basic_gemm.py --num-kernels 2    # Test only the first 2 kernels
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    return parser


def _print_result_row(
    index, spec, name_w, status, time_col="N/A", tflops_col="N/A", err_col="N/A"
):
    """Print one row of the RUNNING KERNELS table (columns pre-formatted as text)."""
    tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
    print(
        f" {index:<3} {spec.name:<{name_w}} {tile:<14} {time_col:>10} {tflops_col:>10} {err_col:>10} {status:<8}"
    )


def _run_kernel(index, spec, args, np_dtype, M, N, K, name_w):
    """Build, run and validate one kernel; print its row and return its record.

    Returns:
        ``(name, status, time_ms, tflops, max_err)`` where status is "PASS",
        "FAIL" or "SKIP"; the numeric fields are 0 unless the kernel ran.
    """
    # Per-kernel seeded generator: yields the same values as the former
    # np.random.seed(...) + np.random.randn(...) calls, without mutating
    # NumPy's global RNG state.
    rng = np.random.RandomState(42 + index * 1000)
    A = (rng.randn(M, K) * 0.1).astype(np_dtype)
    B = (rng.randn(K, N) * 0.1).astype(np_dtype)

    config = create_kernel_config(spec, args.dtype, args.arch)
    try:
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"kernel_{spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        if not setup.success:
            _print_result_row(index, spec, name_w, "FAIL")
            return (spec.name, "FAIL", 0, 0, 0)

        dispatcher = setup.dispatcher
        # The dispatcher may reject this problem size for this kernel.
        if not dispatcher.is_supported(M, N, K):
            _print_result_row(index, spec, name_w, "SKIP")
            return (spec.name, "SKIP", 0, 0, 0)

        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            _print_result_row(index, spec, name_w, "FAIL")
            return (spec.name, "FAIL", 0, 0, 0)

        # NumPy reference: accumulate in fp32, then round to the host dtype.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        # Take the difference in fp32 so the error itself is not rounded to fp16.
        diff = np.asarray(result.output, dtype=np.float32) - C_ref.astype(np.float32)
        max_err = float(np.max(np.abs(diff)))
        # A NaN error compares False here, so NaN output is reported as FAIL.
        status = "PASS" if max_err < 1e-2 else "FAIL"
        _print_result_row(
            index,
            spec,
            name_w,
            status,
            f"{result.time_ms:.4f}",
            f"{result.tflops:.2f}",
            f"{max_err:.2e}",
        )
        return (spec.name, status, result.time_ms, result.tflops, max_err)
    finally:
        # Release dispatcher state on every path, including exceptions.
        cleanup_gemm()


def _print_summary(results, M, N, K, dtype) -> int:
    """Print the SUMMARY section and return the process exit code."""
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    n_passed = sum(1 for r in results if r[1] == "PASS")
    n_skipped = sum(1 for r in results if r[1] == "SKIP")
    n_failed = len(results) - n_passed - n_skipped
    print(f"\n Results: {n_passed}/{len(results)} kernels passed")
    print(f" Problem: {M}x{N}x{K}, dtype={dtype}")
    valid_results = [r for r in results if r[1] == "PASS"]
    if valid_results:
        best = max(valid_results, key=lambda r: r[3])
        print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")
    all_passed = n_passed == len(results)
    if all_passed:
        print("\n *** ALL KERNELS PASSED ***")
    else:
        # Report skipped kernels separately instead of calling them failures.
        parts = []
        if n_failed:
            parts.append(f"{n_failed} KERNELS FAILED")
        if n_skipped:
            parts.append(f"{n_skipped} KERNELS SKIPPED")
        print(f"\n *** {', '.join(parts)} ***")
    print("=" * 70)
    # Any kernel that did not pass (failed or skipped) makes the run unsuccessful.
    return 0 if all_passed else 1


def main():
    """Declare, run and validate every selected GEMM kernel.

    Returns:
        0 if every selected kernel passed, 1 otherwise.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.size <= 0:
        parser.error(f"--size must be a positive integer, got {args.size}")
    reset_for_example()
    print("=" * 70)
    print("Example 01: Basic GEMM with Multiple Kernels")
    print("=" * 70)
    # Select kernels to test; a count of 0 (or negative) means all of them.
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
    # =========================================================================
    # Step 1: Print all kernel configurations
    # =========================================================================
    print_kernel_table(specs, args.dtype)
    # =========================================================================
    # Step 2: Setup and test each kernel
    # =========================================================================
    print("\n" + "=" * 70)
    print(" RUNNING KERNELS")
    print("=" * 70)
    # NOTE(review): NumPy has no bfloat16, so bf16 runs pass float16 host
    # buffers; presumably the dispatcher converts them -- confirm.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    M, N, K = args.size, args.size, args.size
    # Widen the Name column for long spec names (minimum 18, as before).
    name_w = max([18] + [len(spec.name) for spec in specs])
    print(f"\n Problem size: {M}x{N}x{K}\n")
    print(
        f" {'#':<3} {'Name':<{name_w}} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
    )
    print(" " + "-" * (78 + name_w - 18))
    results = []
    for i, spec in enumerate(specs, 1):
        results.append(_run_kernel(i, spec, args, np_dtype, M, N, K, name_w))
    # =========================================================================
    # Step 3: Summary
    # =========================================================================
    return _print_summary(results, M, N, K, args.dtype)
if __name__ == "__main__":
    # Propagate main()'s status (0 = all kernels passed) as the exit code.
    exit_code = main()
    sys.exit(exit_code)