Files
composable_kernel/dispatcher/codegen/fmha/validation.py
Vidyasagar Ananthan 86591de476 [rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher
 (#5260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution
but has no support for Fused Multi-Head Attention (FMHA). The
example/ck_tile/01_fmha folder contains a comprehensive FMHA
implementation with forward, backward, split-KV, paged-KV, append-KV,
and batch-prefill kernels across multiple GPU architectures — but there
is no unified dispatch layer for it. This PR ports the FMHA stack into
the dispatcher, following the same architectural patterns established by
GEMM and Grouped Convolution, enabling runtime kernel selection, JIT
compilation from Python, and a declarative C++ example flow. Autotuning
heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring
GEMM's layered architecture. Seven new C++ runtime headers provide type
definitions (coexisting with upstream headers via __has_include,
requiring zero modifications to example/ck_tile/01_fmha/), a problem
builder with 18+ setters, Signature + Algorithm kernel key matching, a
virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard
support and named tile/wave/warp setters, arch-aware registry with JSON
export, and a dispatcher with seqtune-aware selection, configurable
timing, and multi-stage execution plans for split-KV (two-stage) and
backward (three-stage). The codegen pipeline is driven by a
fmha_arch_specs.json capturing per-arch tile tables and pipeline
constraints for five architectures (gfx90a/942/950/1100/1201), migrated
from hardcoded logic in 01_fmha/codegen/, with supporting modules for
C++ symbol mappings, validation rules, and named receipt profiles
(ck_default, flash, pytorch, aiter, fp32, fp8). Python integration
(fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel
multi-kernel builds, HIP memory management via ctypes, tolerance-based
validation, and a NumPy CPU reference with GQA support. Twenty-seven C++
and thirty-two Python examples cover the full feature surface — forward,
split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill,
fp8, logits soft cap, sink tokens, and parameter sweeps — all
JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end
correctness. C++ unit tests validate the problem builder, dispatcher
planning (single-stage for forward/paged-KV/append-KV; multi-stage for
split-KV and backward), registry operations, and the kernel-set
declaration macro. Python unit tests verify codegen emission, profile
filtering, and 15 validation rules for masks, hdim constraints, and
pipeline requirements. GPU execution validation in 01_basic_fmha
--validate reports zero errors across 65,536 elements with max absolute
error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py)
runs 14 configurations through both the upstream tile_example_fmha_fwd
and the dispatcher, comparing exit codes to confirm behavioral parity —
all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a
Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching
up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-17 07:30:33 +00:00

922 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
FMHA validation and kernel specifications.
Architecture-specific data (dtypes, pipelines, hdims, tile tables) is stored in
``fmha_arch_specs.json`` so that it can be edited without touching Python code.
Common GPU hardware data (element sizes, warp size, LDS capacity) is imported
from the parent ``arch_specs_generated`` module (generated from ``arch_specs.json``).
This file provides:
- JSON loading helpers
- Tile constraints (per-arch rules that reject invalid tiles)
- Feature compatibility rules (pipeline × feature flag interactions)
- Receipt filters and profiles (deployment-specific kernel subsets)
- Config validation for the AOT codegen path
"""
import json
import sys
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple
# Ensure this directory and parent codegen/ are on sys.path for sibling imports
_THIS_DIR = Path(__file__).resolve().parent
_CODEGEN_DIR = _THIS_DIR.parent
sys.path.insert(0, str(_THIS_DIR))
sys.path.insert(0, str(_CODEGEN_DIR))
from symbol_map import ( # noqa: E402
BWD_DTYPE_MAP,
FWD_DTYPE_MAP,
canonical_bias,
canonical_mask,
canonical_qscale,
)
# Import shared hardware data from parent arch_specs_generated (generated from
# arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if
# the generated module is unavailable (e.g. in standalone testing).
try:
from arch_specs_generated import ELEMENT_SIZE_MAP as _PARENT_ELEMENT_SIZES # noqa: E402
except ImportError:
_PARENT_ELEMENT_SIZES = {
"fp16": 2,
"bf16": 2,
"fp32": 4,
"fp64": 8,
"fp8": 1,
"bf8": 1,
"int8": 1,
"int4": 0.5,
"pk_fp4": 0.5,
"int32": 4,
}
# =============================================================================
# JSON data loading
# =============================================================================
_FMHA_SPECS_PATH = _THIS_DIR / "fmha_arch_specs.json"
def _load_fmha_specs() -> dict:
"""Load fmha_arch_specs.json (cached after first call)."""
if not hasattr(_load_fmha_specs, "_cache"):
with open(_FMHA_SPECS_PATH) as f:
_load_fmha_specs._cache = json.load(f)
return _load_fmha_specs._cache
def _build_element_sizes() -> Dict[str, int]:
"""Merge parent element sizes with FMHA-specific composite dtypes."""
base = {k: int(v) for k, v in _PARENT_ELEMENT_SIZES.items()}
base.update(_load_fmha_specs().get("fmha_element_sizes", {}))
return base
# =============================================================================
# 1. Architecture capabilities (loaded from fmha_arch_specs.json)
# =============================================================================
def _build_arch_dtypes() -> Dict[str, List[str]]:
"""Build ARCH_DTYPES from JSON architectures."""
return {
arch: info["supported_dtypes"]
for arch, info in _load_fmha_specs()["architectures"].items()
}
def _build_supported_hdims() -> Dict[str, List[Tuple[int, int]]]:
"""Build SUPPORTED_HDIMS from JSON, converting [q,v] lists to tuples."""
return {
dtype: [tuple(pair) for pair in pairs]
for dtype, pairs in _load_fmha_specs()["supported_hdims"].items()
if dtype != "_comment"
}
def _build_arch_metadata() -> Dict[str, dict]:
"""Build ARCH_METADATA from JSON architectures."""
return dict(_load_fmha_specs()["architectures"])
ARCH_DTYPES: Dict[str, List[str]] = _build_arch_dtypes()
SUPPORTED_HDIMS: Dict[str, List[Tuple[int, int]]] = _build_supported_hdims()
ARCH_METADATA: Dict[str, dict] = _build_arch_metadata()
# =============================================================================
# 2. Tile hardware parameters (loaded from fmha_arch_specs.json + parent arch_specs)
# =============================================================================
def _build_warp_classes() -> Dict[str, List[Tuple[int, int, int]]]:
"""Build WARP_CLASSES from JSON fmha_warp_tiles."""
return {
dtype: [tuple(w) for w in warps]
for dtype, warps in _load_fmha_specs()["fmha_warp_tiles"].items()
if dtype != "_comment"
}
def _build_lds_limits() -> Dict[str, int]:
"""Build LDS_LIMITS from JSON."""
return dict(_load_fmha_specs()["lds_limits"])
def _build_k0max_map() -> Dict[int, int]:
"""Build K0_MAX_SUBMAX_MAP from JSON (string keys → int keys)."""
return {
int(k): v for k, v in _load_fmha_specs()["k0max_map"].items() if k != "_comment"
}
_specs = _load_fmha_specs()
_tile_ranges = _specs["tile_sweep_ranges"]
LDS_LIMITS: Dict[str, int] = _build_lds_limits()
WARP_CLASSES: Dict[str, List[Tuple[int, int, int]]] = _build_warp_classes()
ELEMENT_SIZES: Dict[str, int] = _build_element_sizes()
VALID_BM0: List[int] = _tile_ranges["valid_bm0"]
VALID_BN0: List[int] = _tile_ranges["valid_bn0"]
VALID_BK0: List[int] = _tile_ranges["valid_bk0"]
K0_MAX_SUBMAX_MAP: Dict[int, int] = _build_k0max_map()
# =============================================================================
# 3. Tile constraints
# =============================================================================
def check_gfx9_tile_constraints(
dtype: str,
hdim_q: int,
hdim_v: int,
pipeline: str,
bm0: int,
bn0: int,
bk0: int,
) -> bool:
"""Gfx9 compatibility rules.
Source: fmha_fwd.py CompatibilityRuleFactoryGfx9.check_hdim_tile().
Applies to gfx90a, gfx942, gfx950 for pipelines in {qr, qr_async, qs}.
Note: CK factory is stricter (bm0==128 only for non-128 hdims); we allow
{64, 128, 192, 256} to let the tile engine explore more configurations.
"""
if dtype == "fp32":
return True
if pipeline not in ("qr", "qr_async", "qs"):
return True
if (hdim_q, hdim_v) == (128, 128) and bn0 != 128:
return False
if (hdim_q, hdim_v) == (128, 128) and pipeline == "qr_async" and bm0 != 128:
return False
if (hdim_q, hdim_v) != (128, 128) and bm0 not in (64, 128, 192, 256):
return False
if (hdim_q, hdim_v) == (128, 128) and pipeline != "qr_async" and bk0 == 64:
return False
return True
def check_gfx950_tile_constraints(
hdim_q: int,
hdim_v: int,
pipeline: str,
bm0: int,
bn0: int,
) -> bool:
"""Gfx950 trload/v3 constraints.
Source: fmha_fwd.py CompatibilityRuleFactoryGfx950.check_tile_pipeline().
Note: CK enforces biconditional (v3_tile ↔ v3_pipeline); we only enforce
v3_pipeline → v3_tile since non-v3 pipelines may still use bm0=256.
"""
if pipeline == "qr_async_trload":
if (hdim_q, hdim_v) == (128, 128) and bn0 == 128:
return False
if (hdim_q, hdim_v) not in [(64, 64), (128, 128)]:
return False
is_v3_tile = bm0 == 256
is_v3_pipeline = pipeline == "qr_async_trload_v3"
# v3 pipeline requires bm0=256; other pipelines also allow bm0=256
if is_v3_pipeline and not is_v3_tile:
return False
return True
def check_qr_mfma_insts(
arch: str,
hdim_q: int,
pipeline: str,
bn0: int,
bk0: int,
wn0: int,
wk0: int,
) -> bool:
"""NumMfmaInsts % 8 == 0 check.
Source: block_fmha_pipeline_qr_ks_vs.hpp static_assert at line ~490.
Full C++ formula: (kM0/WarpM)*(kN0/WarpN)*(kK0/WarpK) / (MWarp*NWarp).
We simplify to (bn0/wn0)*(bk0/wk0), omitting (bm0/wm0)/(rm0*rn0) which
equals 1 for all current fp16/bf16/fp32/fp8 tiles, or a power-of-2 factor
for mxfp8/mxfp4 that doesn't change the mod-8 result. This is conservative:
it can only reject tiles the full formula would also reject, never the reverse.
Only applies to qr pipeline + hdim_q==256 + CDNA (gfx9*).
"""
if pipeline != "qr" or hdim_q != 256:
return True
if not arch.startswith("gfx9"):
return True
num_mfma = (bn0 // wn0) * (bk0 // wk0)
if num_mfma % 8 != 0:
return False
return True
def tile_passes_all_constraints(
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
pipeline: str,
bm0: int,
bn0: int,
bk0: int,
wm0: int,
wn0: int,
wk0: int,
) -> bool:
"""Master constraint check — returns True if the tile is valid."""
elem_size = ELEMENT_SIZES.get(dtype, 2)
lds_limit = LDS_LIMITS.get(pipeline, 65536)
# LDS capacity check (pipeline-dependent formula)
if pipeline in ("qr_async", "qr_async_trload", "qr_async_trload_v3"):
# Async pipeline: Q is in registers. LDS holds NumKVLdsBuffers (=3) copies of
# max(SingleKSize, SingleVSize). Derived from GetSmemSizeKV() in
# block_fmha_pipeline_qx_ks_vs_custom_policy.hpp.
#
# SingleVSize formula (MakeVLdsBlockDescriptor):
# Banks=32, PixelsPerRow = Banks*4/sizeof(dtype) = 32*4/elem_size
# kKPack = 16/elem_size (GetSmemKPackV)
# NPerRow = PixelsPerRow/kKPack
# SingleVSize = (bk1/kKPack) * (hdim_v/NPerRow) * (PixelsPerRow + kKPack)
# For bf16: PixelsPerRow=64, kKPack=8, NPerRow=8
# SingleVSize = (32/8)*(hdim_v/8)*(64+8) = 4*(hdim_v/8)*72 = 36*hdim_v
#
# SingleKSize formula (GetSingleSmemElementSpaceSize, async branch):
# KPack = 16/elem_size, KVector = alignment (gfx950: 16/elem_size = 8 for bf16)
# LanesPerK = bk0/KVector, LaneGroups = 64/LanesPerK
# NumIssues = bn0/(LaneGroups*NumWarps)
# SingleKSize = NumIssues*NumWarps*(64*KVector + KPack)
#
bk1 = 32 # kK1 in TileFmhaShape — design choice from fmha_fwd.py tile defs
num_warps = bm0 // wm0
# Banks: arch.hpp get_n_lds_banks() — 64 for gfx950, 32 for older
banks = 64 if arch == "gfx950" else 32
pixels_per_row = banks * 4 // elem_size # Banks * 4bytes / sizeof(dtype)
k_pack = 16 // elem_size # GetSmemKPackV: 16 / sizeof(dtype)
n_per_row = pixels_per_row // k_pack
single_v = (bk1 // k_pack) * (hdim_v // n_per_row) * (pixels_per_row + k_pack)
# KVector: GetAlignmentK in custom_policy.hpp — MaxLoadSizeInBytes / sizeof(dtype)
# gfx950 uses dwordx4 (16 bytes), older uses dword (4 bytes)
k_vector = 16 // elem_size if arch == "gfx950" else 4 // elem_size
lanes_per_k = bk0 // k_vector if k_vector > 0 else 1
lane_groups = 64 // lanes_per_k if lanes_per_k > 0 else 1 # WarpSize=64
num_issues = (
bn0 // (lane_groups * num_warps) if (lane_groups * num_warps) > 0 else 0
)
single_k = num_issues * num_warps * (64 * k_vector + k_pack)
single_buf_bytes = max(single_k, single_v) * elem_size
# NumPrefetchK = NumPrefetchV = 3 (async_default_policy.hpp)
num_kv_buffers = 3
# Q uses registers (QLoadOnce=true), so GetSmemSizeQ() = 0.
total_lds = single_buf_bytes * num_kv_buffers
# gfx950 HW LDS limit: arch.hpp get_smem_capacity() = 163840 (160 KiB)
if total_lds > 163840:
return False
else:
# Non-async (qr/qs): Q and K tiles share LDS simultaneously
if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit:
return False
# bk0 range
if bk0 > hdim_q:
return False
# hdim_q divisibility (tile_fmha_shape.hpp:60)
if hdim_q % bk0 != 0:
return False
# Warp alignment
if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0:
return False
# MFMA inst count
if not check_qr_mfma_insts(arch, hdim_q, pipeline, bn0, bk0, wn0, wk0):
return False
# Async DMA distribution constraint (MakeKLdsStoreBlockDescriptor, custom_policy.hpp).
# NumIssues = kNPerBlock / (LaneGroups * NumWarps) must be a positive integer, where
# LaneGroups = WarpSize / LanesPerK = 64 / (bk0 / KVector).
# Equivalently: (bn0 * bk0) % (kBlockSize * KVector) == 0.
# KVector = MaxLoadSizeInBytes / sizeof(dtype): gfx950=16/2=8, older=4/2=2 for bf16.
if pipeline == "qr_async" and arch.startswith("gfx9"):
kvector = 16 // elem_size if arch == "gfx950" else 4 // elem_size
num_warps = bm0 // wm0
block_size = num_warps * 64 # WarpSize = 64
if (bn0 * bk0) % (block_size * kvector) != 0:
return False
# Arch constraints
if arch in ("gfx90a", "gfx942", "gfx950"):
if not check_gfx9_tile_constraints(
dtype, hdim_q, hdim_v, pipeline, bm0, bn0, bk0
):
return False
if arch == "gfx950":
if not check_gfx950_tile_constraints(hdim_q, hdim_v, pipeline, bm0, bn0):
return False
return True
# =============================================================================
# 4. Feature compatibility rules
# =============================================================================
# Supported mask, bias, and boolean values for feature products.
# These are the template enum values in CK's FMHA traits structs.
MASKS = ["no", "causal", "generic"]
BIASES = ["no", "bias", "alibi"]
BOOLS = ["t", "f"]
# Dtype groups matching CK's _DT_* classification in fmha_fwd.py factory classes.
DT_FP16_BF16 = {"fp16", "bf16"}
DT_FP8 = {"fp8bf16", "fp8", "bf8"}
DT_FP8FP32 = {"fp8fp32"}
DT_FP32 = {"fp32"}
def check_logits_bias(logits: str, bias: str) -> bool:
"""logits_soft_cap requires no bias.
Source: fmha_fwd.py CompatibilityRuleFactory.check_feature().
"""
return not (logits == "t" and bias != "no")
def check_group_mode_padding(mode: str, spad: str, skpad: str) -> bool:
"""Group mode requires spad=t and skpad=t.
Source: fmha_fwd.py CompatibilityRuleFactory.check_feature() +
block_fmha_pipeline static_asserts for padding.
"""
if mode == "group":
return spad == "t" and skpad == "t"
return True
# =============================================================================
# 5. Variant-specific tile tables (loaded from fmha_arch_specs.json)
# =============================================================================
def _build_bwd_tiles() -> Tuple[
Dict[Tuple[int, int], Tuple[int, ...]],
Dict[Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]],
Dict[Tuple[int, int, int, str], dict],
]:
"""Build BWD tile tables from JSON."""
bwd = _load_fmha_specs()["bwd_tiles"]
# Main tiles: "hdimq_hdimv" -> 9-tuple
main = {}
for k, v in bwd["dq_dk_dv_fp16"].items():
hq, hv = map(int, k.split("_"))
main[(hq, hv)] = tuple(v)
# Extra tiles: "hdimq_hdimv" -> [(tile, tag, batch_only), ...]
extra = {}
for k, entries in bwd.get("dq_dk_dv_extra", {}).items():
hq, hv = map(int, k.split("_"))
extra[(hq, hv)] = [
(tuple(e["tile"]), e["tag"], e["batch_only"]) for e in entries
]
# Wave/warp lookup: "bm0_bn0_bk0_trload" -> {wave, warp_k1}
ww = {}
for k, v in _load_fmha_specs()["bwd_wave_warp"].items():
if k.startswith("_"):
continue
parts = k.split("_")
key = (int(parts[0]), int(parts[1]), int(parts[2]), parts[3])
ww[key] = {"wave": tuple(v["wave"]), "warp_k1": v["warp_k1"]}
return main, extra, ww
def _build_splitkv_hdims() -> Tuple[List[int], List[int]]:
"""Build SplitKV combine hdim lists from JSON."""
skv = _load_fmha_specs()["splitkv_combine"]
return skv["hdims_fp16"], skv["hdims_fp8"]
_bwd_main, _bwd_extra, _bwd_ww = _build_bwd_tiles()
_skv_fp16, _skv_fp8 = _build_splitkv_hdims()
SPLITKV_COMBINE_HDIMS_FP16: List[int] = _skv_fp16
SPLITKV_COMBINE_HDIMS_FP8: List[int] = _skv_fp8
BWD_DQ_DK_DV_TILES_FP16: Dict[Tuple[int, int], Tuple[int, ...]] = _bwd_main
BWD_DQ_DK_DV_EXTRA_TILES: Dict[
Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]
] = _bwd_extra
BWD_DQ_WAVE_WARP: Dict[Tuple[int, int, int, str], dict] = _bwd_ww
_bwd_json = _load_fmha_specs()["bwd_tiles"]
BWD_EXTRA_PAD_COMBOS: List[Tuple[str, str]] = [
tuple(p) for p in _bwd_json["extra_pad_combos"]
]
BWD_SMALL_DROPOUTS: List[str] = _bwd_json["small_dropouts"]
BWD_DOT_DO_O_HDIMS: List[int] = _bwd_json["dot_do_o_hdims"]
BWD_CONVERT_DQ_HDIMS: List[int] = _bwd_json["convert_dq_hdims"]
BWD_CONVERT_DQ_TILE_GROUPS: Dict[int, int] = {
int(k): v for k, v in _bwd_json["convert_dq_tile_groups"].items()
}
BWD_DROPOUTS: List[str] = _bwd_json["dropouts"]
BWD_PAD_COMBOS: List[Tuple[str, str]] = [tuple(p) for p in _bwd_json["pad_combos"]]
# =============================================================================
# 6. Receipt filters
# =============================================================================
class Receipt(IntEnum):
"""Named receipt levels for deployment profiles.
These are deployment-specific filters, not derived from C++ constraints.
They control which kernel subsets are emitted for different integration
targets (PyTorch, AITER, Flash-Attention, etc.).
"""
CK_DEFAULT = 0
CK_EXTENDED = 1
FLASH_FWD = 2
FLASH_BWD = 3
PYTORCH = 4
AITER_BATCH = 100
AITER_GROUP = 200
AITER_BWD_BATCH = 300
AITER_BWD_GROUP = 400
AITER_CPP = 600
FP32_ALL = 800
FP32_MIN = 801
FP8_TEST = 888
RECEIPT_FILTERS: Dict[int, Callable[[str, object], bool]] = {
0: lambda dtype, spec: dtype != "fp32",
2: lambda dtype, spec: (
dtype in ("fp16", "bf16")
and getattr(spec, "bias", "no") in ("no", "alibi")
and getattr(spec, "qscale", "no") == "no"
and getattr(spec, "skip", "f") == "f"
and getattr(spec, "sink", "f") == "f"
),
4: lambda dtype, spec: (
dtype in ("fp16", "bf16")
and getattr(spec, "bias", "no") in ("no", "bias")
and getattr(spec, "qscale", "no") == "no"
and getattr(spec, "skip", "f") == "f"
and getattr(spec, "logits", "f") == "f"
),
100: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
200: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
600: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
888: lambda dtype, spec: dtype in ("fp8bf16", "fp8fp32"),
800: lambda dtype, spec: (
dtype == "fp32"
and getattr(spec, "skip", "f") == "f"
and getattr(spec, "logits", "f") == "f"
),
}
def receipt_filter(receipt: int, dtype: str, spec) -> bool:
"""Apply receipt-level filter. Returns True if the kernel should be kept."""
fn = RECEIPT_FILTERS.get(receipt)
if fn is None:
return dtype != "fp32"
return fn(dtype, spec)
# =============================================================================
# 7. Profiles
# =============================================================================
PROFILE_ALIASES: Dict[str, str] = {str(r.value): r.name.lower() for r in Receipt}
@dataclass(frozen=True)
class FmhaProfile:
name: str
predicate: Callable[[dict], bool]
def allows(self, config: dict) -> bool:
return self.predicate(config)
def _dtype_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["data_type"] in set(allowed)
def _mode_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["mode"] in set(allowed)
def _family_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["family"] in set(allowed)
def _common_row_major_filter(config: dict) -> bool:
return config["signature"]["vlayout"] == "r"
def _bias_is(config: dict, allowed: Iterable[str]) -> bool:
return canonical_bias(config["signature"]["bias"]) in set(allowed)
def _qscale_is(config: dict, allowed: Iterable[str]) -> bool:
return canonical_qscale(config["signature"]["qscale"]) in set(allowed)
def _no_skip_or_logits(config: dict) -> bool:
return (not config["signature"]["skip_min_seqlen_q"]) and (
not config["signature"]["logits"]
)
PROFILES: Dict[str, FmhaProfile] = {
"ck_default": FmhaProfile(
"ck_default", lambda c: c["signature"]["data_type"] != "fp32"
),
"ck_extended": FmhaProfile(
"ck_extended", lambda c: c["signature"]["data_type"] != "fp32"
),
"flash_fwd": FmhaProfile(
"flash_fwd",
lambda c: (
_family_is(c, {"fwd", "fwd_splitkv", "fwd_appendkv", "fwd_pagedkv"})
and _dtype_is(c, {"fp16", "bf16"})
and _common_row_major_filter(c)
and _bias_is(c, {"no", "alibi"})
and _qscale_is(c, {"no"})
and not c["signature"]["skip_min_seqlen_q"]
),
),
"flash_bwd": FmhaProfile(
"flash_bwd",
lambda c: (
_family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
and _dtype_is(c, {"fp16", "bf16"})
),
),
"pytorch": FmhaProfile(
"pytorch",
lambda c: (
_dtype_is(c, {"fp16", "bf16"})
and _common_row_major_filter(c)
and _bias_is(c, {"no", "bias"})
and _qscale_is(c, {"no"})
and _no_skip_or_logits(c)
and not c["signature"].get("sink", False)
),
),
"aiter_batch": FmhaProfile(
"aiter_batch",
lambda c: (
_dtype_is(c, {"fp16", "bf16", "fp8bf16"})
and _mode_is(c, {"batch"})
and _common_row_major_filter(c)
and (
c["signature"]["data_type"] != "fp8bf16"
or c["signature"]["hdim_q"] in {128, 192}
)
),
),
"aiter_group": FmhaProfile(
"aiter_group",
lambda c: (
_dtype_is(c, {"fp16", "bf16", "fp8bf16"})
and _mode_is(c, {"group"})
and _common_row_major_filter(c)
),
),
"aiter_bwd_batch": FmhaProfile(
"aiter_bwd_batch",
lambda c: (
_family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
and _dtype_is(c, {"fp16", "bf16"})
and _mode_is(c, {"batch"})
),
),
"aiter_bwd_group": FmhaProfile(
"aiter_bwd_group",
lambda c: (
_family_is(c, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
and _dtype_is(c, {"fp16", "bf16"})
and _mode_is(c, {"group"})
),
),
"aiter_cpp": FmhaProfile(
"aiter_cpp",
lambda c: (
_dtype_is(c, {"fp16", "bf16", "fp8bf16"})
and _common_row_major_filter(c)
and (
c["signature"]["data_type"] != "fp8bf16"
or c["signature"]["hdim_q"] in {128, 192}
)
),
),
"fp32_all": FmhaProfile(
"fp32_all", lambda c: _dtype_is(c, {"fp32"}) and _no_skip_or_logits(c)
),
"fp32_min": FmhaProfile(
"fp32_min",
lambda c: (
_dtype_is(c, {"fp32"})
and _mode_is(c, {"batch"})
and c["signature"]["hdim_q"] in {48, 128}
and c["signature"]["hdim_v"] in {48, 128}
and canonical_bias(c["signature"]["bias"]) == "no"
and not c["signature"]["lse"]
and not c["signature"]["dropout"]
and canonical_qscale(c["signature"]["qscale"]) == "no"
),
),
"fp8_test": FmhaProfile(
"fp8_test",
lambda c: (
_dtype_is(c, {"fp8bf16", "fp8fp32"})
and c["signature"]["hdim_q"] in {128, 192}
and _common_row_major_filter(c)
),
),
"all": FmhaProfile("all", lambda _: True),
}
def normalize_profile(
profile: Optional[str] = None, receipt: Optional[str] = None
) -> str:
if profile:
return PROFILE_ALIASES.get(str(profile), str(profile))
if receipt is not None:
return PROFILE_ALIASES.get(str(receipt), str(receipt))
return "ck_default"
def get_profile(
profile: Optional[str] = None, receipt: Optional[str] = None
) -> FmhaProfile:
normalized = normalize_profile(profile=profile, receipt=receipt)
if normalized not in PROFILES:
raise KeyError(f"Unknown FMHA profile: {normalized}")
return PROFILES[normalized]
def profile_allows(
config: dict, profile: Optional[str] = None, receipt: Optional[str] = None
) -> bool:
return get_profile(profile=profile, receipt=receipt).allows(config)
# =============================================================================
# 8. Validation helpers (for unified_fmha_codegen)
# =============================================================================
_DEFAULTS: dict = _load_fmha_specs()["defaults"]
_GLOBAL_RULES: dict = _load_fmha_specs()["global_rules"]
def load_arch_specs() -> dict:
"""Return arch_specs dict compatible with unified_fmha_codegen.
Combines FMHA-specific architecture data from fmha_arch_specs.json with
defaults, global rules, and splitkv combine params.
"""
specs = _load_fmha_specs()
return {
"architectures": ARCH_METADATA,
"defaults": _DEFAULTS,
"global_rules": _GLOBAL_RULES,
"splitkv_combine": specs["splitkv_combine"],
}
# =============================================================================
# 9. Config validation (for unified_fmha_codegen)
# =============================================================================
@dataclass
class ValidationResult:
valid: bool = True
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
def add_error(self, msg: str):
self.valid = False
self.errors.append(msg)
def add_warning(self, msg: str):
self.warnings.append(msg)
def validate_config(
config: dict, arch_specs: Optional[dict] = None
) -> "ValidationResult":
"""Validate an FMHA kernel config against all rules."""
arch_specs = arch_specs or load_arch_specs()
result = ValidationResult()
sig = config["signature"]
alg = config["algorithm"]
arch = config["arch"]
architectures = arch_specs.get("architectures", ARCH_METADATA)
if arch not in architectures:
result.add_error(f"Unsupported FMHA target architecture: {arch}")
return result
arch_info = architectures[arch]
global_rules = arch_specs.get("global_rules", _GLOBAL_RULES)
dtype = sig["data_type"]
family = sig["family"]
pipeline = alg["pipeline"]
canonical_mask(sig["mask"])
bias = canonical_bias(sig["bias"])
# Family validation
supported_families = {
"fwd",
"fwd_pagedkv",
"fwd_splitkv",
"fwd_splitkv_combine",
"fwd_appendkv",
"batch_prefill",
"bwd_dot_do_o",
"bwd_dq_dk_dv",
"bwd_convert_dq",
}
if family not in supported_families:
result.add_error(f"Unsupported FMHA family: {family}")
# Dtype validation
supported_dtypes = set(arch_info["supported_dtypes"])
if dtype not in supported_dtypes:
result.add_error(f"dtype {dtype} is not supported on {arch}")
if family.startswith("bwd") and dtype not in BWD_DTYPE_MAP:
result.add_error(
f"Backward family {family} only supports {sorted(BWD_DTYPE_MAP)}"
)
if (
family.startswith("fwd")
and not family.startswith("fwd_append")
and dtype not in FWD_DTYPE_MAP
):
result.add_error(f"Forward family {family} does not recognize dtype {dtype}")
# Pipeline validation
if (
family != "fwd_splitkv_combine"
and pipeline not in arch_info["supported_pipelines"]
):
result.add_error(f"pipeline {pipeline} is not supported on {arch}")
if pipeline in {"v3", "qr_async_trload_v3"} and not arch_info.get(
"supports_v3", False
):
result.add_warning(f"v3 pipeline on {arch} requires supports_v3 in arch specs")
if pipeline == "qr_async_trload" and not arch_info.get("supports_trload", False):
result.add_error("qr_async_trload requires a trload-capable architecture")
# Global rules
hdim_q = sig["hdim_q"]
hdim_v = sig["hdim_v"]
divisor = global_rules.get("hdim_divisible_by", 8)
if hdim_q % divisor != 0 or hdim_v % divisor != 0:
result.add_error(f"Head dimensions must be multiples of {divisor}")
if global_rules.get("hdim_192_128_no_bias_dropout"):
if (
hdim_q == 192
and hdim_v == 128
and (bias != "no" or sig.get("dropout", False))
):
result.add_warning(
"hdim (192,128) with bias/dropout has limited tile support"
)
if global_rules.get("logits_requires_no_bias"):
if bias != "no" and sig.get("logits", False):
result.add_error("logits_soft_cap cannot be combined with bias")
if pipeline in {"qr_async_trload", "v3", "qr_async_trload_v3"} and (
hdim_q != hdim_v or hdim_q not in {64, 128}
):
result.add_error(f"{pipeline} only supports symmetric head dims 64 or 128")
# Tile validation
tile = alg["tile"]
expected_tile_len = 9 if family == "bwd_dq_dk_dv" else 6
if len(tile) != expected_tile_len or len(alg["wave"]) != 9 or len(alg["warp"]) != 9:
result.add_error(
f"tile/wave/warp must have {expected_tile_len}/9/9 elements for {family}"
)
# MFMA instruction count check for qr/h256/CDNA
_1d_families = {"bwd_dot_do_o", "bwd_convert_dq"}
if (
pipeline == "qr"
and hdim_q == 256
and family not in _1d_families
and arch_info.get("family", "").startswith("cdna")
and len(tile) >= 3
and len(alg["wave"]) >= 2
and len(alg["warp"]) >= 3
):
wm, wn, wk = alg["warp"][0], alg["warp"][1], alg["warp"][2]
gm, gn = alg["wave"][0], alg["wave"][1]
if wm > 0 and wn > 0 and wk > 0 and gm > 0 and gn > 0:
num_mfma = (tile[0] // wm) * (tile[1] // wn) * (tile[2] // wk) // (gm * gn)
if num_mfma % 8 != 0:
result.add_error(
f"NumMfmaInsts={num_mfma} must be divisible by 8 for qr/h256/CDNA"
)
if alg["block_per_cu"] <= 0 and alg["block_per_cu"] != -1:
result.add_error("block_per_cu must be positive or -1 (auto)")
if alg["num_wave_groups"] <= 0:
result.add_error("num_wave_groups must be positive")
# --- Family-specific rules ---
if family == "batch_prefill":
if sig.get("vlayout", "r") != "r":
result.add_error("batch_prefill only supports row-major V layout")
if not sig.get("paged_kv", False):
result.add_error("batch_prefill requires paged_kv=true")
ps = sig.get("page_size", 0)
if ps <= 0 or (ps & (ps - 1)) != 0:
result.add_error("batch_prefill page_size must be a positive power of two")
if sig.get("mode", "batch") != "group":
result.add_error("batch_prefill requires group mode")
if pipeline != "qr_async":
result.add_error("batch_prefill currently uses qr_async pipeline")
if family == "fwd_appendkv":
if sig.get("mode", "batch") != "batch":
result.add_error("fwd_appendkv uses batch-mode public API surface")
if pipeline != "appendkv":
result.add_error("fwd_appendkv must use appendkv pipeline")
if sig.get("vlayout", "r") != "r":
result.add_error("fwd_appendkv currently only supports row-major V")
if family == "fwd_splitkv_combine":
if sig.get("mode", "batch") not in {"batch", "group"}:
result.add_error("fwd_splitkv_combine requires batch or group mode")
combine_bn1 = arch_specs.get("splitkv_combine", {}).get("combine_bn1", 32)
if len(tile) > 3 and tile[3] != combine_bn1:
result.add_error(f"fwd_splitkv_combine requires bn1={combine_bn1}")
if len(tile) > 3 and (hdim_v < tile[3] or hdim_v % tile[3] != 0):
result.add_error("fwd_splitkv_combine requires hdim_v divisible by bn1")
if family == "fwd_pagedkv":
if pipeline != "qr_pagedkv":
result.add_error("fwd_pagedkv must use qr_pagedkv pipeline")
if not sig.get("paged_kv", False):
result.add_error("fwd_pagedkv requires paged_kv=true")
if sig.get("vlayout", "r") != "r":
result.add_error("fwd_pagedkv currently only supports row-major V")
if family == "fwd_splitkv":
if pipeline not in {"qr", "qr_nwarp_sshuffle"}:
result.add_error("fwd_splitkv must use qr or qr_nwarp_sshuffle pipeline")
if sig.get("vlayout", "r") != "r":
result.add_error("fwd_splitkv currently only supports row-major V")
if family == "fwd" and sig.get("vlayout", "r") != "r":
result.add_warning("dispatcher forward examples currently assume row-major V")
return result