Files
composable_kernel/dispatcher/codegen/preselected_kernels.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

519 lines
15 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Preselected, Benchmarked Kernel Configurations
Curated kernel sets optimized for different workload characteristics:
- Compute-friendly: Large tiles, high arithmetic intensity
- Memory-friendly: Smaller tiles, better memory access patterns
- Latency-friendly: Minimal tiles, low latency for small problems
"""
from functools import partial, lru_cache
from typing import List
from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant
# ============================================================================
# Base Configurations
# ============================================================================
def _base_fp16_rcr_compute() -> partial:
    """Return a KernelConfig factory for compute-intensive FP16 RCR kernels.

    The returned ``partial`` pre-binds the compute-friendly trait (compv4
    pipeline, cshuffle epilogue, intrawave scheduler, full padding) and launch
    parameters; callers must supply a concrete ``tile=TileConfig(...)``.
    """
    compute_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # placeholder — caller overrides with a concrete TileConfig
        trait=compute_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
def _base_fp16_rcr_memory() -> partial:
    """Return a KernelConfig factory for memory-intensive FP16 RCR kernels.

    Callers must supply a concrete ``tile=TileConfig(...)``.
    The 'mem' pipeline is required here because the interwave scheduler is
    used (compv3/compv4/compv5/compv6 only support intrawave).
    """
    memory_trait = TraitConfig(
        pipeline="mem",
        epilogue="cshuffle",
        scheduler="interwave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # placeholder — caller overrides with a concrete TileConfig
        trait=memory_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
def _base_fp16_rcr_latency() -> partial:
    """Return a KernelConfig factory for latency-sensitive FP16 RCR kernels.

    Callers must supply a concrete ``tile=TileConfig(...)``.  Uses the 'mem'
    pipeline with the default epilogue to minimize per-launch overhead.
    """
    latency_trait = TraitConfig(
        pipeline="mem",
        epilogue="default",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # placeholder — caller overrides with a concrete TileConfig
        trait=latency_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
# ============================================================================
# Preselected FP16 RCR Kernels
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_compute() -> List[KernelConfig]:
    """
    Compute-friendly FP16 RCR kernels.

    Optimized for:
    - Large M, N dimensions (>= 128)
    - High arithmetic intensity
    - Good occupancy
    - Maximum throughput

    NOTE(review): the cached return value is a mutable list shared across
    callers — callers should not mutate it.
    """
    base = _base_fp16_rcr_compute()
    # Large tiles for maximum compute, then balanced tiles.
    tiles = [
        TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
        TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16),
        TileConfig(256, 128, 32, 4, 2, 1, 32, 32, 16),
        TileConfig(128, 256, 32, 2, 4, 1, 32, 32, 16),
        TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
        TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16),
    ]
    kernels = [base(tile=t) for t in tiles]
    # One persistent-kernel variant (no padding) for large batches.
    kernels.append(
        base(
            tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
            trait=TraitConfig(
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=False,
                pad_n=False,
                pad_k=False,
                persistent=True,
            ),
        )
    )
    return kernels
@lru_cache(None)
def preselected_fp16_rcr_memory() -> List[KernelConfig]:
    """
    Memory-friendly FP16 RCR kernels.

    Optimized for:
    - Small to medium M, N dimensions
    - Memory-bound workloads
    - Better cache utilization
    - Lower register pressure
    """
    base = _base_fp16_rcr_memory()
    # Small tiles (16x16x16 warp tile) followed by medium tiles (32x32x16).
    tiles = [
        TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16),
        TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16),
        TileConfig(16, 64, 32, 1, 2, 1, 16, 16, 16),
        TileConfig(64, 16, 32, 2, 1, 1, 16, 16, 16),
        TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16),
        TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16),
        TileConfig(32, 128, 32, 1, 2, 1, 32, 32, 16),
        TileConfig(128, 32, 32, 2, 1, 1, 32, 32, 16),
    ]
    return [base(tile=t) for t in tiles]
@lru_cache(None)
def preselected_fp16_rcr_latency() -> List[KernelConfig]:
    """
    Latency-friendly FP16 RCR kernels.

    Optimized for:
    - Very small M, N dimensions (< 64)
    - Minimal launch overhead
    - Low latency
    - Quick execution
    """
    base = _base_fp16_rcr_latency()
    # Minimal tiles for low latency.
    return [
        base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)),
        base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)),
    ]
# ============================================================================
# Preselected Multi-D Kernels
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_multi_d() -> List[KernelConfig]:
    """
    Multi-D GEMM kernels with element-wise fusion.

    Common fusions:
    - MultiDAdd: E = C + D0 + D1
    - Relu: E = max(C, 0)
    - Gelu: E = gelu(C)
    """
    base = _base_fp16_rcr_compute()
    # Best-performing tile for fused operations.
    fused_tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
    # One config per (element-wise op, number of D tensors) pair.
    return [
        base(
            tile=fused_tile,
            variant=GemmVariant.MULTI_D,
            elementwise_op=ew_op,
            num_d_tensors=num_d,
        )
        for ew_op in ("MultiDAdd", "Relu", "Gelu", "FastGelu")
        for num_d in (1, 2)
    ]
@lru_cache(None)
def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]:
    """
    Preshuffle GEMM kernels for weight optimization.

    Best for:
    - Repeated use of same weights
    - Inference workloads
    - Batch size > 1
    """
    base = _base_fp16_rcr_compute()
    preshuffle_tiles = [
        TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
        TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
    ]
    return [
        base(tile=t, variant=GemmVariant.PRESHUFFLE, preshuffle=True)
        for t in preshuffle_tiles
    ]
# ============================================================================
# Unified Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_all() -> List[KernelConfig]:
    """Concatenation of every preselected FP16 RCR kernel set."""
    combined: List[KernelConfig] = []
    # Order matters only for display; keep the original concatenation order.
    for subset in (
        preselected_fp16_rcr_compute,
        preselected_fp16_rcr_memory,
        preselected_fp16_rcr_latency,
        preselected_fp16_rcr_multi_d,
        preselected_fp16_rcr_preshuffle,
    ):
        combined.extend(subset())
    return combined
@lru_cache(None)
def preselected_fp16_rcr_essential() -> List[KernelConfig]:
    """
    Essential FP16 RCR kernels - minimal set for most workloads.

    Covers:
    - 90% of common GEMM sizes
    - Key fusion operations
    - Balanced performance
    """
    compute = _base_fp16_rcr_compute()
    memory = _base_fp16_rcr_memory()
    fusion_tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
    kernels = [
        # Top compute kernels
        compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
        # Top memory kernels
        memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)),
        memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)),
    ]
    # Essential single-D fusions (Relu, Gelu) on the balanced tile.
    for op in ("Relu", "Gelu"):
        kernels.append(
            compute(
                tile=fusion_tile,
                variant=GemmVariant.MULTI_D,
                elementwise_op=op,
                num_d_tensors=1,
            )
        )
    return kernels
# ============================================================================
# Default Fallback
# ============================================================================
def default_kernel() -> KernelConfig:
    """
    Default fallback kernel - guaranteed to work.

    Known-good configuration tested on gfx942.
    """
    fallback_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return KernelConfig(
        tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
        trait=fallback_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
# ============================================================================
# BF16 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_bf16_rcr_essential() -> List[KernelConfig]:
    """Essential BF16 RCR kernels.

    Reuses the FP16 compute-friendly base, whose bound trait and
    block_size=256 are identical to the inline partial this function
    previously duplicated.  This also makes k_block_per_cu=1 and
    num_wave_groups=1 explicit, matching every other config in this file
    (the duplicated partial omitted them and relied on KernelConfig
    defaults — TODO confirm those defaults were 1).
    """
    base = _base_fp16_rcr_compute()
    return [
        base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
# ============================================================================
# INT8 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_int8_rcr_essential() -> List[KernelConfig]:
    """Essential INT8 RCR kernels for quantized inference.

    Reuses the FP16 compute-friendly base, whose bound trait and
    block_size=256 are identical to the inline partial this function
    previously duplicated.  This also makes k_block_per_cu=1 and
    num_wave_groups=1 explicit, matching every other config in this file
    (the duplicated partial omitted them and relied on KernelConfig
    defaults — TODO confirm those defaults were 1).  INT8 uses K=64 tiles.
    """
    base = _base_fp16_rcr_compute()
    return [
        base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
# ============================================================================
# FP8 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_fp8_rcr_essential() -> List[KernelConfig]:
    """Essential FP8 RCR kernels for AI training.

    Reuses the FP16 compute-friendly base, whose bound trait and
    block_size=256 are identical to the inline partial this function
    previously duplicated.  This also makes k_block_per_cu=1 and
    num_wave_groups=1 explicit, matching every other config in this file
    (the duplicated partial omitted them and relied on KernelConfig
    defaults — TODO confirm those defaults were 1).  FP8 uses K=64 tiles.
    """
    base = _base_fp16_rcr_compute()
    return [
        base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
# ============================================================================
# Mixed Precision Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_mixed_precision() -> List[KernelConfig]:
    """Mixed-precision kernels (FP16 inputs, FP32 output).

    NOTE(review): the dtypes are not encoded in KernelConfig here —
    presumably they are selected elsewhere in the codegen; verify.

    Reuses the FP16 compute-friendly base, whose bound trait and
    block_size=256 are identical to the inline partial this function
    previously duplicated.  This also makes k_block_per_cu=1 and
    num_wave_groups=1 explicit, matching every other config in this file
    (the duplicated partial omitted them and relied on KernelConfig
    defaults — TODO confirm those defaults were 1).
    """
    base = _base_fp16_rcr_compute()
    return [
        base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
# ============================================================================
# Registry
# ============================================================================
# Registry mapping a set name (as accepted by get_preselected_set / the CLI
# --set flag) to the zero-argument, lru_cache-backed factory that builds it.
PRESELECTED_SETS = {
    # FP16 sets
    "fp16_rcr_compute": preselected_fp16_rcr_compute,
    "fp16_rcr_memory": preselected_fp16_rcr_memory,
    "fp16_rcr_latency": preselected_fp16_rcr_latency,
    "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d,
    "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle,
    "fp16_rcr_all": preselected_fp16_rcr_all,
    "fp16_rcr_essential": preselected_fp16_rcr_essential,
    # BF16 sets
    "bf16_rcr_essential": preselected_bf16_rcr_essential,
    # INT8 sets
    "int8_rcr_essential": preselected_int8_rcr_essential,
    # FP8 sets
    "fp8_rcr_essential": preselected_fp8_rcr_essential,
    # Mixed precision
    "mixed_precision": preselected_mixed_precision,
}
def get_preselected_set(name: str) -> List[KernelConfig]:
    """Get a preselected kernel set by name.

    Raises:
        ValueError: if *name* is not a key of PRESELECTED_SETS.
    """
    factory = PRESELECTED_SETS.get(name)
    if factory is None:
        raise ValueError(
            f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}"
        )
    return factory()
def list_preselected_sets() -> List[str]:
    """Return the names of all available preselected sets."""
    # Iterating the dict yields its keys; no .keys() call needed.
    return list(PRESELECTED_SETS)
# ============================================================================
# CLI for testing
# ============================================================================
if __name__ == "__main__":
    import argparse

    # Small CLI for inspecting the preselected sets defined above.
    parser = argparse.ArgumentParser(
        description="List preselected kernel configurations"
    )
    parser.add_argument(
        "--set",
        type=str,
        default="fp16_rcr_essential",
        choices=list_preselected_sets(),
        help="Preselected set to display",
    )
    parser.add_argument("--count-only", action="store_true", help="Only show count")
    cli_args = parser.parse_args()

    kernels = get_preselected_set(cli_args.set)
    if cli_args.count_only:
        print(f"{cli_args.set}: {len(kernels)} kernels")
    else:
        print(f"Preselected set: {cli_args.set}")
        print(f"Total kernels: {len(kernels)}\n")
        for idx, cfg in enumerate(kernels, start=1):
            print(f"{idx}. {cfg.variant.value}")
            print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
            print(f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}")
            # Fusion details only apply to multi-D variants.
            if cfg.variant == GemmVariant.MULTI_D:
                print(
                    f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}"
                )
            print()