Files
composable_kernel/dispatcher/codegen/arch_specs_generated.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.
2026-01-22 09:34:33 -08:00

359 lines
11 KiB
Python

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
Generated from: arch_specs.json
Generated at: 2026-01-05T19:34:01.224422
To update this file:
1. Edit arch_specs.json
2. Run: python generate_arch_specs.py
This module provides architecture-specific configurations for kernel filtering.
"""
from typing import Dict, List, Set, Tuple
# =============================================================================
# Architecture Data (Generated from arch_specs.json)
# =============================================================================
# GPU architecture to family mapping
ARCH_FAMILY_MAP: Dict[str, str] = {
"gfx908": "cdna1",
"gfx90a": "cdna2",
"gfx942": "cdna3",
"gfx950": "cdna4",
"gfx1100": "rdna3",
"gfx1200": "rdna4",
"gfx1201": "rdna4",
}
# Element size in bytes for each data type
ELEMENT_SIZE_MAP: Dict[str, float] = {
"fp16": 2,
"bf16": 2,
"fp32": 4,
"fp64": 8,
"fp8": 1,
"bf8": 1,
"int8": 1,
"int4": 0.5,
"pk_fp4": 0.5,
"int32": 4,
}
# Supported warp configurations per architecture [warp_m, warp_n, warp_k]
WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {
"gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
"gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
"gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
}
# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...]
WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
"gfx908": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
},
"gfx90a": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
},
"gfx942": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]],
"bf8_fp8_fp32": [[32, 32, 16]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
},
"gfx950": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp32": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"fp8_bf8_fp32": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
"bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]],
"bf8_bf8_fp32": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
"pk_fp4_pk_fp4_fp32": [[16, 16, 128]],
},
"gfx1100": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]],
},
"gfx1200": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]],
},
"gfx1201": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]],
},
}
# Preshuffle-specific warp tile combinations (subset of standard GEMM)
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
"gfx90a": {
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
"fp16_fp16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_fp32": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp32": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"bf8_bf8_fp32": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
},
}
# Preshuffle-supported pipelines
PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"]
# LDS capacity limits per pipeline type (in bytes)
LDS_CAPACITY_LIMITS: Dict[str, int] = {
"mem": 65536,
"compv1": 65536,
"compv2": 65536,
"compv3": 65536,
"compv4": 32768,
"compv5": 65536,
"preshufflev1": 32768,
"preshufflev2": 32768,
"default": 65536,
}
# Unsupported trait combinations: (pipeline, epilogue, scheduler)
TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
("compv5", "cshuffle", "interwave"),
("compv5", "default", "interwave"),
("compv6", "cshuffle", "interwave"),
("compv6", "default", "interwave"),
("comp_async", "cshuffle", "interwave"),
("comp_async", "default", "interwave"),
}
# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes
DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {
"fp32_fp32": {"acc": "fp32", "notes": "Full precision"},
"fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"},
"bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"},
"fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"},
"fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"},
"bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"},
"bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"},
"int8_int8": {"acc": "int32", "notes": "Integer GEMM"},
"pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"},
}
# =============================================================================
# Helper Functions
# =============================================================================
def get_supported_archs() -> List[str]:
"""Get list of all supported GPU architectures."""
return list(ARCH_FAMILY_MAP.keys())
def get_arch_family(gpu_arch: str) -> str:
"""Get the GPU family for an architecture."""
return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown")
def get_element_size(dtype: str) -> float:
"""Get element size in bytes for a data type."""
return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0)
def get_warp_configs(gpu_arch: str) -> List[List[int]]:
"""Get supported warp configurations for an architecture."""
return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), [])
def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]:
"""Get supported warp tile combinations for arch and data types."""
gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {})
return gpu_combos.get(dtype_key.lower(), [])
def get_lds_limit(pipeline: str) -> int:
"""Get LDS capacity limit for a pipeline type."""
return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"])
def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool:
"""Check if a trait combination is unsupported."""
return (
pipeline.lower(),
epilogue.lower(),
scheduler.lower(),
) in TRAIT_UNSUPPORTED_COMBINATIONS
def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]:
"""Get accumulator type and notes for a dtype combination."""
key = f"{dtype_a.lower()}_{dtype_b.lower()}"
return DTYPE_COMBINATIONS.get(key, {"acc": "fp32", "notes": "unknown"})
def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool:
"""Check if a dtype combination is valid."""
key = f"{dtype_a.lower()}_{dtype_b.lower()}"
return key in DTYPE_COMBINATIONS
def get_valid_dtype_combos() -> List[str]:
"""Get list of all valid dtype combinations."""
return list(DTYPE_COMBINATIONS.keys())