mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving code generation * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removal.
This commit is contained in:
committed by
GitHub
parent
44f481a45c
commit
9e049a32a1
518
dispatcher/codegen/preselected_kernels.py
Normal file
518
dispatcher/codegen/preselected_kernels.py
Normal file
@@ -0,0 +1,518 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Preselected, Benchmarked Kernel Configurations
|
||||
|
||||
Curated kernel sets optimized for different workload characteristics:
|
||||
- Compute-friendly: Large tiles, high arithmetic intensity
|
||||
- Memory-friendly: Smaller tiles, better memory access patterns
|
||||
- Latency-friendly: Minimal tiles, low latency for small problems
|
||||
"""
|
||||
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Base Configurations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _base_fp16_rcr_compute() -> partial:
    """Factory for compute-heavy FP16 RCR kernel templates.

    Returns a ``partial`` over ``KernelConfig`` with ``tile`` left as
    ``None``; callers supply ``tile=...`` (and may override any other
    keyword, e.g. ``trait`` or ``variant``).
    """
    compute_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # caller must supply the tile shape
        trait=compute_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
def _base_fp16_rcr_memory() -> partial:
    """Factory for memory-bound FP16 RCR kernel templates.

    Returns a ``partial`` over ``KernelConfig`` with ``tile`` left as
    ``None``; callers supply ``tile=...``.

    The 'mem' pipeline is required here because the interwave scheduler
    is not supported by compv3/compv4/compv5/compv6 (intrawave only).
    """
    memory_trait = TraitConfig(
        pipeline="mem",
        epilogue="cshuffle",
        scheduler="interwave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # caller must supply the tile shape
        trait=memory_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
def _base_fp16_rcr_latency() -> partial:
    """Factory for latency-sensitive FP16 RCR kernel templates.

    Returns a ``partial`` over ``KernelConfig`` with ``tile`` left as
    ``None``; callers supply ``tile=...``. Unlike the other bases this
    one uses the 'default' epilogue rather than 'cshuffle'.
    """
    latency_trait = TraitConfig(
        pipeline="mem",
        epilogue="default",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return partial(
        KernelConfig,
        tile=None,  # caller must supply the tile shape
        trait=latency_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Preselected FP16 RCR Kernels
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_compute() -> List[KernelConfig]:
    """
    Compute-friendly FP16 RCR kernels.

    Optimized for large M, N dimensions (>= 128), high arithmetic
    intensity, good occupancy and maximum throughput. Cached, so the
    same list object is returned on every call - treat it as read-only.
    """
    make = _base_fp16_rcr_compute()

    # Positional TileConfig arguments mirror the other preselected sets;
    # large tiles first, then balanced 128x128 tiles.
    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),
        (256, 256, 64, 4, 4, 1, 32, 32, 16),
        (256, 128, 32, 4, 2, 1, 32, 32, 16),
        (128, 256, 32, 2, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),
        (128, 128, 64, 2, 2, 1, 32, 32, 16),
    ]
    configs = [make(tile=TileConfig(*shape)) for shape in tile_shapes]

    # One persistent-kernel variant (no padding) for large batch counts.
    configs.append(
        make(
            tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
            trait=TraitConfig(
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=False,
                pad_n=False,
                pad_k=False,
                persistent=True,
            ),
        )
    )
    return configs
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_memory() -> List[KernelConfig]:
    """
    Memory-friendly FP16 RCR kernels.

    Optimized for small-to-medium M, N dimensions, memory-bound
    workloads, better cache utilization and lower register pressure.
    Cached, so the same list object is returned on every call.
    """
    make = _base_fp16_rcr_memory()

    # Small 16x16x16-instruction tiles first, then medium 32x32x16 tiles.
    tile_shapes = [
        (16, 32, 32, 1, 1, 1, 16, 16, 16),
        (32, 16, 32, 1, 1, 1, 16, 16, 16),
        (16, 64, 32, 1, 2, 1, 16, 16, 16),
        (64, 16, 32, 2, 1, 1, 16, 16, 16),
        (32, 64, 32, 1, 1, 1, 32, 32, 16),
        (64, 32, 32, 1, 1, 1, 32, 32, 16),
        (32, 128, 32, 1, 2, 1, 32, 32, 16),
        (128, 32, 32, 2, 1, 1, 32, 32, 16),
    ]
    return [make(tile=TileConfig(*shape)) for shape in tile_shapes]
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_latency() -> List[KernelConfig]:
    """
    Latency-friendly FP16 RCR kernels.

    Optimized for very small M, N dimensions (< 64), minimal launch
    overhead and quick execution. Cached, so the same list object is
    returned on every call.
    """
    make = _base_fp16_rcr_latency()

    # Only the two smallest tiles are worth keeping for low latency.
    minimal_shapes = [
        (16, 32, 32, 1, 1, 1, 16, 16, 16),
        (32, 16, 32, 1, 1, 1, 16, 16, 16),
    ]
    return [make(tile=TileConfig(*shape)) for shape in minimal_shapes]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Preselected Multi-D Kernels
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_multi_d() -> List[KernelConfig]:
    """
    Multi-D GEMM kernels with element-wise fusion.

    Covers the common fused epilogues - MultiDAdd (E = C + D0 + D1),
    Relu (E = max(C, 0)), Gelu and FastGelu - each with 1 and 2 extra
    D tensors, all on a single well-performing 128x128x32 tile.
    """
    make = _base_fp16_rcr_compute()
    fused_tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)

    # Same iteration order as before: op varies slowest, num_d fastest.
    return [
        make(
            tile=fused_tile,
            variant=GemmVariant.MULTI_D,
            elementwise_op=op,
            num_d_tensors=num_d,
        )
        for op in ("MultiDAdd", "Relu", "Gelu", "FastGelu")
        for num_d in (1, 2)
    ]
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]:
    """
    Preshuffle GEMM kernels for weight optimization.

    Best for repeated use of the same weights, inference workloads and
    batch size > 1. Cached, so the same list object is returned on
    every call.
    """
    make = _base_fp16_rcr_compute()

    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),
    ]
    return [
        make(
            tile=TileConfig(*shape),
            variant=GemmVariant.PRESHUFFLE,
            preshuffle=True,
        )
        for shape in tile_shapes
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unified Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_all() -> List[KernelConfig]:
    """Union of every preselected FP16 RCR set, in a fixed order."""
    merged: List[KernelConfig] = []
    for subset in (
        preselected_fp16_rcr_compute,
        preselected_fp16_rcr_memory,
        preselected_fp16_rcr_latency,
        preselected_fp16_rcr_multi_d,
        preselected_fp16_rcr_preshuffle,
    ):
        merged.extend(subset())
    return merged
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_essential() -> List[KernelConfig]:
    """
    Essential FP16 RCR kernels - minimal set for most workloads.

    Covers roughly 90% of common GEMM sizes, the key fusion operations
    (Relu / Gelu), and balanced performance. Cached, so the same list
    object is returned on every call.
    """
    make_compute = _base_fp16_rcr_compute()
    make_memory = _base_fp16_rcr_memory()
    fusion_tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)

    essentials: List[KernelConfig] = [
        # Two best general-purpose compute tiles.
        make_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        make_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
        # Two best memory-bound tiles.
        make_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)),
        make_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)),
    ]
    # The two most common fused epilogues, one extra D tensor each.
    for op in ("Relu", "Gelu"):
        essentials.append(
            make_compute(
                tile=fusion_tile,
                variant=GemmVariant.MULTI_D,
                elementwise_op=op,
                num_d_tensors=1,
            )
        )
    return essentials
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Default Fallback
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def default_kernel() -> KernelConfig:
    """
    Default fallback kernel - guaranteed to work.

    Known-good configuration tested on gfx942. Not cached: each call
    builds a fresh KernelConfig.
    """
    safe_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        # Full padding so arbitrary problem sizes are accepted.
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return KernelConfig(
        tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
        trait=safe_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BF16 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_bf16_rcr_essential() -> List[KernelConfig]:
    """Essential BF16 RCR kernels (two proven compute tiles)."""
    trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # NOTE(review): unlike the FP16 bases, k_block_per_cu and
    # num_wave_groups are left to KernelConfig defaults here - confirm
    # that is intentional.
    make = partial(
        KernelConfig,
        tile=None,
        trait=trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        make(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        make(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INT8 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_int8_rcr_essential() -> List[KernelConfig]:
    """Essential INT8 RCR kernels for quantized inference.

    Uses deeper K=64 tiles than the FP16 sets.
    """
    trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # NOTE(review): k_block_per_cu / num_wave_groups rely on
    # KernelConfig defaults here, unlike the FP16 bases - confirm.
    make = partial(
        KernelConfig,
        tile=None,
        trait=trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        make(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        make(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FP8 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp8_rcr_essential() -> List[KernelConfig]:
    """Essential FP8 RCR kernels for AI training.

    Same tile/trait choices as the INT8 essential set (K=64 tiles).
    """
    trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # NOTE(review): k_block_per_cu / num_wave_groups rely on
    # KernelConfig defaults here, unlike the FP16 bases - confirm.
    make = partial(
        KernelConfig,
        tile=None,
        trait=trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        make(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        make(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mixed Precision Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_mixed_precision() -> List[KernelConfig]:
    """Mixed-precision kernels (FP16 inputs, FP32 output)."""
    trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # NOTE(review): k_block_per_cu / num_wave_groups rely on
    # KernelConfig defaults here, unlike the FP16 bases - confirm.
    make = partial(
        KernelConfig,
        tile=None,
        trait=trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        make(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        make(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
# Registry mapping a set name to a zero-argument factory returning
# List[KernelConfig]. The factories are lru_cache'd, so repeated lookups
# return the same list objects - treat the results as read-only.
PRESELECTED_SETS = {
    # FP16 sets
    "fp16_rcr_compute": preselected_fp16_rcr_compute,
    "fp16_rcr_memory": preselected_fp16_rcr_memory,
    "fp16_rcr_latency": preselected_fp16_rcr_latency,
    "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d,
    "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle,
    "fp16_rcr_all": preselected_fp16_rcr_all,
    "fp16_rcr_essential": preselected_fp16_rcr_essential,
    # BF16 sets
    "bf16_rcr_essential": preselected_bf16_rcr_essential,
    # INT8 sets
    "int8_rcr_essential": preselected_int8_rcr_essential,
    # FP8 sets
    "fp8_rcr_essential": preselected_fp8_rcr_essential,
    # Mixed precision
    "mixed_precision": preselected_mixed_precision,
}
|
||||
|
||||
|
||||
def get_preselected_set(name: str) -> List[KernelConfig]:
    """Get a preselected kernel set by name.

    Args:
        name: Key into ``PRESELECTED_SETS`` (e.g. ``"fp16_rcr_essential"``).

    Returns:
        The list of kernel configurations produced by the registered
        (cached) factory.

    Raises:
        ValueError: If ``name`` is not a registered set.
    """
    # EAFP: single dict lookup instead of `in` check + second lookup.
    try:
        factory = PRESELECTED_SETS[name]
    except KeyError:
        raise ValueError(
            f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}"
        ) from None
    return factory()
|
||||
|
||||
|
||||
def list_preselected_sets() -> List[str]:
    """Return the names of all registered preselected kernel sets."""
    return [set_name for set_name in PRESELECTED_SETS]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI for testing
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
    # Small CLI for inspecting the preselected sets defined above.
    import argparse

    parser = argparse.ArgumentParser(
        description="List preselected kernel configurations"
    )
    parser.add_argument(
        "--set",
        type=str,
        default="fp16_rcr_essential",
        choices=list_preselected_sets(),
        help="Preselected set to display",
    )
    parser.add_argument("--count-only", action="store_true", help="Only show count")
    args = parser.parse_args()

    selected = get_preselected_set(args.set)

    if args.count_only:
        print(f"{args.set}: {len(selected)} kernels")
    else:
        print(f"Preselected set: {args.set}")
        print(f"Total kernels: {len(selected)}\n")
        for idx, cfg in enumerate(selected, 1):
            tile, trait = cfg.tile, cfg.trait
            print(f"{idx}. {cfg.variant.value}")
            print(f" Tile: {tile.tile_m}x{tile.tile_n}x{tile.tile_k}")
            print(f" Pipeline: {trait.pipeline}, Epilogue: {trait.epilogue}")
            # Multi-D variants carry fusion metadata worth displaying.
            if cfg.variant == GemmVariant.MULTI_D:
                print(
                    f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}"
                )
            print()
|
||||
Reference in New Issue
Block a user