Files
composable_kernel/dispatcher/codegen/fmha/instance_gen.py
Vidyasagar Ananthan 86591de476 [rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher
 (#5260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution
but has no support for Fused Multi-Head Attention (FMHA). The
example/ck_tile/01_fmha folder contains a comprehensive FMHA
implementation with forward, backward, split-KV, paged-KV, append-KV,
and batch-prefill kernels across multiple GPU architectures — but there
is no unified dispatch layer for it. This PR ports the FMHA stack into
the dispatcher, following the same architectural patterns established by
GEMM and Grouped Convolution, enabling runtime kernel selection, JIT
compilation from Python, and a declarative C++ example flow. Autotuning
heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring
GEMM's layered architecture. Seven new C++ runtime headers provide type
definitions (coexisting with upstream headers via __has_include,
requiring zero modifications to example/ck_tile/01_fmha/), a problem
builder with 18+ setters, Signature + Algorithm kernel key matching, a
virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard
support and named tile/wave/warp setters, arch-aware registry with JSON
export, and a dispatcher with seqtune-aware selection, configurable
timing, and multi-stage execution plans for split-KV (two-stage) and
backward (three-stage). The codegen pipeline is driven by a
fmha_arch_specs.json capturing per-arch tile tables and pipeline
constraints for five architectures (gfx90a/942/950/1100/1201), migrated
from hardcoded logic in 01_fmha/codegen/, with supporting modules for
C++ symbol mappings, validation rules, and named receipt profiles
(ck_default, flash, pytorch, aiter, fp32, fp8). Python integration
(fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel
multi-kernel builds, HIP memory management via ctypes, tolerance-based
validation, and a NumPy CPU reference with GQA support. Twenty-seven C++
and thirty-two Python examples cover the full feature surface — forward,
split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill,
fp8, logits soft cap, sink tokens, and parameter sweeps — all
JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end
correctness. C++ unit tests validate the problem builder, dispatcher
planning (single-stage for forward/paged-KV/append-KV; multi-stage for
split-KV and backward), registry operations, and the kernel-set
declaration macro. Python unit tests verify codegen emission, profile
filtering, and 15 validation rules for masks, hdim constraints, and
pipeline requirements. GPU execution validation in 01_basic_fmha
--validate reports zero errors across 65,536 elements with max absolute
error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py)
runs 14 configurations through both the upstream tile_example_fmha_fwd
and the dispatcher, comparing exit codes to confirm behavioral parity —
all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a
Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching
up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-17 07:30:33 +00:00

2693 lines
86 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
FMHA instance generation — generates tile configs and expands kernel instances.
Three layers:
1. Tile generation — enumerate valid (bm0, bn0, bk0, warp) combinations
2. Feature enumeration — enumerate valid (mask, bias, lse, dropout, padding) combinations
3. Instance expansion — cross-product tiles × features × modes → kernel configs
All hardware facts and constraints come from specs.py.
All symbol mappings come from symbol_map.py.
Usage:
python -m fmha.instance_gen configs/receipt0_fwd.json --arch gfx950
python -m fmha.instance_gen configs/fwd_ci.json --arch gfx950 --list
"""
import argparse
import itertools
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
_THIS_DIR = Path(__file__).resolve().parent
_DISPATCHER_ROOT = _THIS_DIR.parents[1]
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
sys.path.insert(0, str(_THIS_DIR))
from validation import ( # noqa: E402
ARCH_DTYPES,
BIASES,
BOOLS,
BWD_CONVERT_DQ_HDIMS,
BWD_CONVERT_DQ_TILE_GROUPS,
BWD_DOT_DO_O_HDIMS,
BWD_DQ_DK_DV_EXTRA_TILES,
BWD_DQ_DK_DV_TILES_FP16,
BWD_DQ_WAVE_WARP,
BWD_DROPOUTS,
BWD_EXTRA_PAD_COMBOS,
BWD_PAD_COMBOS,
BWD_SMALL_DROPOUTS,
DT_FP16_BF16,
DT_FP32,
DT_FP8,
DT_FP8FP32,
ELEMENT_SIZES,
K0_MAX_SUBMAX_MAP,
LDS_LIMITS,
MASKS,
SPLITKV_COMBINE_HDIMS_FP16,
SPLITKV_COMBINE_HDIMS_FP8,
SUPPORTED_HDIMS,
VALID_BK0,
VALID_BM0,
VALID_BN0,
WARP_CLASSES,
check_gfx9_tile_constraints,
check_gfx950_tile_constraints,
check_group_mode_padding,
check_logits_bias,
check_qr_mfma_insts,
receipt_filter,
tile_passes_all_constraints,
)
from fmha_utils import FmhaKernelConfig # noqa: E402 (from dispatcher/python/)
# =============================================================================
# Tile configuration dataclass
# =============================================================================
@dataclass(frozen=True)
class FmhaTileConfig:
"""Complete FMHA tile configuration with all derived parameters.
Field naming follows CK's TileFmhaShape template parameters:
- bm0/bn0/bk0: block tile for Gemm0 (Q*K^T), from sequence<M,N,K,N1,K1,K0Max>
- bn1/bk1: block tile for Gemm1 (P*V)
- bk0max: kSubQKHeaddim from tile_fmha_shape.hpp
- rm0: wave repeat in M direction = bm0/wm0
- wm0/wn0/wk0: MFMA/WMMA warp tile from warp_gemm_dispatcher.hpp
"""
bm0: int
bn0: int
bk0: int
bn1: int # = hdim_v
bk1: int # = 32 typically
bk0max: int # = K0_MAX_SUBMAX_MAP[hdim_q]
rm0: int # wave repeat = bm0/wm0
wm0: int
wn0: int
wk0: int
wm1: int
wn1: int
wk1: int
rn0: int = 1
rk0: int = 1
rm1: int = 1
rn1: int = 1
rk1: int = 1
@property
def tile_6(self) -> Tuple[int, int, int, int, int, int]:
return (self.bm0, self.bn0, self.bk0, self.bn1, self.bk1, self.bk0max)
# =============================================================================
# BK1 derivation
# =============================================================================
def derive_bk1(bm0: int, bn0: int, bk0: int, hdim_q: int, hdim_v: int) -> int:
"""Derive bk1 from tile config for fp16/bf16/fp32.
Source: fmha_fwd.py FmhaFwdTileSize definitions — bk1 (element 4) is
always 32 except for three specific configs where it's 16.
These special cases come from the CK example's hand-tuned tile tables.
"""
if (bm0, bn0, bk0, hdim_q) in (
(128, 64, 32, 128),
(32, 128, 32, 128),
(32, 128, 16, 48),
):
return 16
return 32
def derive_bk1_fp8(bm0: int, bn0: int, bk0: int, hdim_q: int, hdim_v: int) -> int:
"""Derive bk1 for fp8 dtypes.
Source: fmha_fwd.py FP8 tile definitions — bk1 always equals bk0.
"""
return bk0
# =============================================================================
# Tile generation
# =============================================================================
def generate_fwd_tiles(
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
pipeline: str = "qr_async",
apply_constraints: bool = True,
) -> List[FmhaTileConfig]:
"""Generate fwd tile configurations.
apply_constraints=True (default): filter through tile_passes_all_constraints
— used by rules-mode benchmarking and codegen.
apply_constraints=False: only basic sanity (warp alignment, bk0<=hdim_q)
— used by exhaustive-mode benchmarking to find tiles the C++ compiler
accepts that our rules might reject.
"""
warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)])
bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q)
is_fp8 = "fp8" in dtype or dtype in ("bf8", "mxfp8", "mxfp4")
tiles: List[FmhaTileConfig] = []
for bm0 in VALID_BM0:
for bn0 in VALID_BN0:
for bk0 in VALID_BK0:
if bk0 > hdim_q:
continue
for wm0, wn0, wk0 in warp_classes:
if bm0 % wm0 != 0 or bn0 % wn0 != 0 or bk0 % wk0 != 0:
continue
if apply_constraints and not tile_passes_all_constraints(
arch,
dtype,
hdim_q,
hdim_v,
pipeline,
bm0,
bn0,
bk0,
wm0,
wn0,
wk0,
):
continue
rm0 = bm0 // wm0
bk1 = (
derive_bk1_fp8(bm0, bn0, bk0, hdim_q, hdim_v)
if is_fp8
else derive_bk1(bm0, bn0, bk0, hdim_q, hdim_v)
)
tiles.append(
FmhaTileConfig(
bm0=bm0,
bn0=bn0,
bk0=bk0,
bn1=hdim_v,
bk1=bk1,
bk0max=bk0max,
rm0=rm0,
rm1=rm0,
wm0=wm0,
wn0=wn0,
wk0=wk0,
wm1=wm0,
wn1=wn0,
wk1=wk0,
)
)
return tiles
def generate_splitkv_tiles(
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
apply_constraints: bool = True,
) -> List[FmhaTileConfig]:
"""Generate splitkv tiles.
Uses fixed warp class per dtype: (16,16,16) for fp16/bf16/fp32,
(32,32,32) for fp8. These match the warp tiles used in the CK example's
splitkv tile definitions (fmha_fwd.py KernelComponentFactory*.get_splitkv_tiles()).
LDS limit: 64 KiB (non-async pipeline, arch.hpp get_smem_capacity for non-gfx950).
apply_constraints=False skips LDS check (for exhaustive mode).
"""
bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q)
is_fp8 = "fp8" in dtype or dtype == "bf8"
wm0, wn0, wk0 = (32, 32, 32) if is_fp8 else (16, 16, 16)
tiles: List[FmhaTileConfig] = []
for bm0 in VALID_BM0:
for bn0 in VALID_BN0:
for bk0 in VALID_BK0:
if bk0 > hdim_q:
continue
if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0:
continue
if apply_constraints:
elem_size = ELEMENT_SIZES.get(dtype, 2)
lds_limit = LDS_LIMITS.get("qr", 65536)
if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit:
continue
rm0 = bm0 // wm0
bk1 = bk0 if is_fp8 else 32
tiles.append(
FmhaTileConfig(
bm0=bm0,
bn0=bn0,
bk0=bk0,
bn1=hdim_v,
bk1=bk1,
bk0max=bk0max,
rm0=rm0,
rm1=rm0,
wm0=wm0,
wn0=wn0,
wk0=wk0,
wm1=wm0,
wn1=wn0,
wk1=wk0,
)
)
return tiles
def generate_pagedkv_tiles(
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
apply_constraints: bool = True,
) -> List[FmhaTileConfig]:
"""PagedKV uses same tile rules as splitkv."""
return generate_splitkv_tiles(arch, dtype, hdim_q, hdim_v, apply_constraints)
def generate_bwd_tiles(
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
apply_constraints: bool = True,
) -> List[FmhaTileConfig]:
"""Generate BWD tile configurations.
apply_constraints=False skips LDS check (for exhaustive mode).
"""
warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)])
bk0max = K0_MAX_SUBMAX_MAP.get(hdim_q, hdim_q)
is_fp8 = "fp8" in dtype or dtype in ("bf8", "mxfp8", "mxfp4")
tiles: List[FmhaTileConfig] = []
for bm0 in VALID_BM0:
for bn0 in VALID_BN0:
for bk0 in VALID_BK0:
if bk0 > hdim_q:
continue
for wm0, wn0, wk0 in warp_classes:
if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0:
continue
if apply_constraints:
elem_size = ELEMENT_SIZES.get(dtype, 2)
lds_limit = LDS_LIMITS.get("qs", 65536)
if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit:
continue
rm0 = bm0 // wm0
bk1 = bk0 if is_fp8 else 32
tiles.append(
FmhaTileConfig(
bm0=bm0,
bn0=bn0,
bk0=bk0,
bn1=hdim_v,
bk1=bk1,
bk0max=bk0max,
rm0=rm0,
rm1=rm0,
wm0=wm0,
wn0=wn0,
wk0=wk0,
wm1=wm0,
wn1=wn0,
wk1=wk0,
)
)
return tiles
def validate_tile(
tile: "FmhaTileConfig",
arch: str,
dtype: str,
hdim_q: int,
hdim_v: int,
pipeline: str = "qr_async",
) -> bool:
"""Validate a single tile configuration against all constraints."""
return tile_passes_all_constraints(
arch,
dtype,
hdim_q,
hdim_v,
pipeline,
tile.bm0,
tile.bn0,
tile.bk0,
tile.wm0,
tile.wn0,
tile.wk0,
)
# =============================================================================
# Pipeline spec dataclasses
# =============================================================================
@dataclass(frozen=True)
class PipelineSpec:
"""One FWD pipeline variant with its feature flags and padding."""
tag: str
mask: str
bias: str
lse: str
dropout: str
logits: str
skip: str
sink: str
qscale: str = "no"
spad: str = "f"
skpad: str = "f"
dpad: str = "f"
dvpad: str = "f"
@dataclass(frozen=True)
class SplitKVPipelineSpec:
"""Split-KV main kernel pipeline variant."""
tag: str
mask: str
bias: str
logits: str
sink: str
pagedkv: str = "f"
squant: str = "f"
spad: str = "f"
skpad: str = "f"
dpad: str = "f"
dvpad: str = "f"
lse: str = "t"
@dataclass(frozen=True)
class SplitKVCombineSpec:
"""Split-KV combine kernel pipeline variant."""
spad: str
dvpad: str
lse: str
squant: str = "f"
@dataclass(frozen=True)
class AppendKVPipelineSpec:
"""Append-KV pipeline variant."""
rope: str = "none"
pagedkv: str = "f"
spad: str = "t"
skpad: str = "t"
dpad: str = "t"
dvpad: str = "t"
@dataclass(frozen=True)
class BatchPrefillPipelineSpec:
"""Batch prefill pipeline variant."""
mask: str
bias: str
logits: str
sink: str
lse: str = "f"
dropout: str = "f"
skip: str = "f"
qscale: str = "no"
page_size: int = 0
kv_memory_layout: str = "vectorized"
kv_lookup_table: str = "sglang"
spad: str = "t"
skpad: str = "t"
dpad: str = "t"
dvpad: str = "t"
@dataclass(frozen=True)
class BwdPipelineSpec:
"""BWD pipeline variant."""
family: str
mask: str = "no"
bias: str = "no"
dbias: str = "f"
dropout: str = "f"
deterministic: str = "f"
spad: str = "t"
skpad: str = "t"
dpad: str = "t"
dvpad: str = "t"
# =============================================================================
# Feature-product generators
# =============================================================================
def _fwd_specs_fp16bf16(
hdim: int,
hdim_v: int,
receipt: int,
) -> List[PipelineSpec]:
"""Pipeline specs for fp16/bf16 on gfx9/gfx950.
Source: fmha_fwd.py KernelComponentFactoryGfx9.get_pipelines() —
hdim=256 always uses 'qr' (non-async, since bk0 can equal 256).
Non-256 hdims use 'qr_async' for non-bias configs (async DMA),
'qr' for bias configs (bias requires Q in LDS).
Receipt=1 (ck_extended) adds extra 'qr' variants for non-bias.
"""
specs: List[PipelineSpec] = []
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
BOOLS,
MASKS,
BIASES,
BOOLS,
BOOLS,
BOOLS,
BOOLS,
):
if hdim == 256 and hdim_v == 256:
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
else:
if bias == "bias":
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
else:
specs.append(
PipelineSpec(
"qr_async",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="f",
dpad="t",
dvpad="t",
)
)
specs.append(
PipelineSpec(
"qr_async",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
if receipt == 1 and bias != "bias":
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
return specs
def _fwd_specs_gfx950_extra(hdim: int, hdim_v: int) -> List[PipelineSpec]:
"""Additional trload/v3 pipelines for gfx950 fp16/bf16.
Source: fmha_fwd.py CompatibilityRuleFactoryGfx950 —
qr_async_trload only supports hdims (64,64) and (128,128),
requires no logits/bias/dropout/skip.
qr_async_trload_v3 only supports (128,128), no/causal mask only.
"""
specs: List[PipelineSpec] = []
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
BOOLS,
MASKS,
BIASES,
BOOLS,
BOOLS,
BOOLS,
BOOLS,
):
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
and logits == "f"
and bias == "no"
and dropout == "f"
and skip == "f"
):
specs.append(
PipelineSpec(
"qr_async_trload",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr_async_trload",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="f",
dpad="t",
dvpad="t",
)
)
if (hdim, hdim_v) == (128, 128):
for logits, mask in itertools.product(BOOLS, ["no", "causal"]):
specs.append(
PipelineSpec(
"qr_async_trload_v3",
mask,
"no",
"f",
"f",
logits,
"f",
"f",
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
return specs
def _fwd_specs_fp8(hdim: int, hdim_v: int) -> List[PipelineSpec]:
"""Pipeline specs for fp8bf16/fp8fp32.
Source: fmha_fwd.py KernelComponentFactoryGfx9._DT_FP8 pipelines —
hdim=64 uses 'qr' (non-async), others use 'qr_async'.
FP8 supports pertensor and blockscale quantization (qscale).
No lse, dropout, skip, or bias for fp8.
"""
specs: List[PipelineSpec] = []
for logits, qscale, mask, bias, sink in itertools.product(
BOOLS,
["no", "pertensor", "blockscale"],
MASKS,
["no"],
BOOLS,
):
tag = "qr" if hdim == 64 else "qr_async"
specs.append(
PipelineSpec(
tag,
mask,
bias,
"f",
"f",
logits,
"f",
sink,
qscale=qscale,
spad="t",
skpad="f",
dpad="t",
dvpad="t",
)
)
specs.append(
PipelineSpec(
tag,
mask,
bias,
"f",
"f",
logits,
"f",
sink,
qscale=qscale,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
return specs
def _fwd_specs_fp32(hdim: int, hdim_v: int) -> List[PipelineSpec]:
"""Pipeline specs for fp32.
Source: fmha_fwd.py KernelComponentFactoryGfx9._DT_FP32 —
always uses 'qr' pipeline (no async for fp32).
Full feature set (mask, bias, lse, dropout, logits, etc.).
"""
specs: List[PipelineSpec] = []
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
BOOLS,
MASKS,
BIASES,
BOOLS,
BOOLS,
BOOLS,
BOOLS,
):
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="f",
skpad="t",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr",
mask,
bias,
lse,
dropout,
logits,
skip,
sink,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
return specs
def get_pipelines_for_config(
arch: str,
dtype: str,
hdim: int,
hdim_v: int,
receipt: int = 0,
) -> List[PipelineSpec]:
"""Get all valid pipeline specs for a given (arch, dtype, hdim, hdim_v, receipt)."""
if dtype in DT_FP32:
specs = _fwd_specs_fp32(hdim, hdim_v)
elif dtype in DT_FP16_BF16:
specs = _fwd_specs_fp16bf16(hdim, hdim_v, receipt)
if arch == "gfx950":
specs.extend(_fwd_specs_gfx950_extra(hdim, hdim_v))
elif dtype in DT_FP8 or dtype in DT_FP8FP32:
specs = _fwd_specs_fp8(hdim, hdim_v)
else:
return []
return [
s
for s in specs
if check_logits_bias(s.logits, s.bias) and receipt_filter(receipt, dtype, s)
]
# --- SplitKV ---
def get_splitkv_pipelines(
dtype: str, hdim: int, receipt: int = 0
) -> List[SplitKVPipelineSpec]:
"""Split-KV main kernel pipelines."""
specs: List[SplitKVPipelineSpec] = []
SPLITKV_MASKS = ["no", "causal"]
if dtype in DT_FP16_BF16:
for logits, mask, bias, pagedkv, sink in itertools.product(
BOOLS, SPLITKV_MASKS, BIASES, BOOLS, BOOLS
):
if not check_logits_bias(logits, bias):
continue
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
sink,
pagedkv,
spad="f",
skpad="t",
dpad="f",
dvpad="f",
)
)
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
sink,
pagedkv,
spad="t",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
sink,
pagedkv,
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
sink,
pagedkv,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
elif dtype in ("fp8", "bf8"):
for logits, mask, bias in itertools.product(BOOLS, SPLITKV_MASKS, BIASES):
if not check_logits_bias(logits, bias):
continue
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
"f",
"f",
squant="t",
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
SplitKVPipelineSpec(
"qr",
mask,
bias,
logits,
"f",
"f",
squant="t",
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
if receipt != 0:
specs = [s for s in specs if _splitkv_receipt_filter(receipt, dtype, s)]
return specs
def _splitkv_receipt_filter(
receipt: int, dtype: str, spec: SplitKVPipelineSpec
) -> bool:
if receipt == 2:
return (
dtype in ("fp16", "bf16")
and spec.bias in ("no", "alibi")
and spec.squant == "f"
and spec.sink == "f"
)
if receipt == 4:
return (
dtype in ("fp16", "bf16")
and spec.bias in ("no", "bias")
and spec.squant == "f"
and spec.sink == "f"
)
if receipt == 200:
return dtype in ("fp16", "bf16") and spec.squant == "f"
if receipt == 600:
return dtype in ("fp16", "bf16") and spec.squant == "f"
if receipt in (800, 801):
return dtype == "fp32"
return True
def get_splitkv_combine_pipelines(
dtype: str, receipt: int = 0
) -> List[SplitKVCombineSpec]:
"""Split-KV combine kernel pipelines."""
specs: List[SplitKVCombineSpec] = []
squant = "t" if dtype in ("fp8", "bf8") else "f"
if dtype in DT_FP16_BF16:
for spad, dvpad, lse in itertools.product(BOOLS, BOOLS, BOOLS):
specs.append(SplitKVCombineSpec(spad, dvpad, lse, squant))
elif dtype in ("fp8", "bf8"):
for spad, dvpad in itertools.product(BOOLS, BOOLS):
specs.append(SplitKVCombineSpec(spad, dvpad, "f", squant))
return specs
# --- PagedKV ---
def get_pagedkv_pipelines(
dtype: str, hdim: int, receipt: int = 0
) -> List[PipelineSpec]:
"""PagedKV prefill pipelines."""
specs: List[PipelineSpec] = []
if dtype in DT_FP16_BF16:
for logits, mask, bias, sink in itertools.product(BOOLS, MASKS, BIASES, BOOLS):
if not check_logits_bias(logits, bias):
continue
specs.append(
PipelineSpec(
"qr_pagedkv",
mask,
bias,
"f",
"f",
logits,
"f",
sink,
spad="t",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr_pagedkv",
mask,
bias,
"f",
"f",
logits,
"f",
sink,
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
elif dtype in ("fp8", "bf8"):
for logits, mask, bias in itertools.product(BOOLS, MASKS, BIASES):
if not check_logits_bias(logits, bias):
continue
specs.append(
PipelineSpec(
"qr_pagedkv",
mask,
bias,
"f",
"f",
logits,
"f",
"f",
spad="f",
skpad="f",
dpad="f",
dvpad="f",
)
)
specs.append(
PipelineSpec(
"qr_pagedkv",
mask,
bias,
"f",
"f",
logits,
"f",
"f",
spad="t",
skpad="t",
dpad="f",
dvpad="f",
)
)
if receipt != 0:
specs = [s for s in specs if receipt_filter(receipt, dtype, s)]
return specs
# --- AppendKV ---
def get_appendkv_pipelines(
dtype: str, hdim: int, receipt: int = 0
) -> List[AppendKVPipelineSpec]:
"""Append-KV pipelines."""
specs: List[AppendKVPipelineSpec] = []
if dtype in DT_FP16_BF16:
for pagedkv in ["t", "f"]:
specs.append(
AppendKVPipelineSpec(
rope="none",
pagedkv=pagedkv,
spad="f",
skpad="t",
dpad="f",
dvpad="f",
)
)
specs.append(
AppendKVPipelineSpec(
rope="none",
pagedkv=pagedkv,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
specs.append(
AppendKVPipelineSpec(
rope="interleaved",
pagedkv=pagedkv,
spad="f",
skpad="t",
dpad="t",
dvpad="f",
)
)
specs.append(
AppendKVPipelineSpec(
rope="interleaved",
pagedkv=pagedkv,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
specs.append(
AppendKVPipelineSpec(
rope="half_rotated",
pagedkv=pagedkv,
spad="f",
skpad="t",
dpad="t",
dvpad="f",
)
)
specs.append(
AppendKVPipelineSpec(
rope="half_rotated",
pagedkv=pagedkv,
spad="t",
skpad="t",
dpad="t",
dvpad="t",
)
)
elif dtype in ("fp8", "bf8"):
specs.append(
AppendKVPipelineSpec(
rope="none", pagedkv="f", spad="t", skpad="t", dpad="t", dvpad="t"
)
)
return specs
# --- Batch Prefill ---
def get_batch_prefill_pipelines(
dtype: str, hdim: int, receipt: int = 0
) -> List[BatchPrefillPipelineSpec]:
"""Batch prefill pipelines."""
specs: List[BatchPrefillPipelineSpec] = []
PREFILL_MASKS = ["no", "causal"]
if dtype in DT_FP16_BF16:
for logits, mask, bias, lse, dropout, kvl, kvt in itertools.product(
BOOLS,
PREFILL_MASKS,
BIASES,
BOOLS,
BOOLS,
["vectorized", "linear"],
["vllm", "sglang"],
):
if not check_logits_bias(logits, bias):
continue
specs.append(
BatchPrefillPipelineSpec(
mask,
bias,
logits,
"f",
lse,
dropout,
"f",
page_size=0,
kv_memory_layout=kvl,
kv_lookup_table=kvt,
)
)
elif dtype == "fp8bf16":
for logits, qscale, mask, bias, kvl, kvt in itertools.product(
BOOLS,
["pertensor", "kv_blockscale"],
MASKS,
["no"],
["vectorized", "linear"],
["vllm", "sglang"],
):
if not check_logits_bias(logits, bias):
continue
specs.append(
BatchPrefillPipelineSpec(
mask,
bias,
logits,
"f",
"f",
"f",
"f",
qscale=qscale,
page_size=0,
kv_memory_layout=kvl,
kv_lookup_table=kvt,
)
)
return specs
# --- BWD ---
def get_bwd_dq_dk_dv_pipelines(dtype: str, receipt: int = 0) -> List[BwdPipelineSpec]:
"""BWD dq_dk_dv feature product."""
if dtype not in DT_FP16_BF16:
return []
specs: List[BwdPipelineSpec] = []
for mask, bias, dbias, dropout, deterministic in itertools.product(
MASKS,
BIASES,
BOOLS,
BWD_DROPOUTS,
BOOLS,
):
if bias != "bias" and dbias == "t":
continue
for dpad, dvpad in BWD_PAD_COMBOS:
specs.append(
BwdPipelineSpec(
"bwd_dq_dk_dv",
mask,
bias,
dbias,
dropout,
deterministic,
dpad=dpad,
dvpad=dvpad,
)
)
return specs
def get_bwd_dq_dk_dv_extra_pipelines(
dtype: str, is_small: bool = False, receipt: int = 0
) -> List[BwdPipelineSpec]:
"""BWD dq_dk_dv extra tile pipelines (reduced feature set)."""
if dtype not in DT_FP16_BF16:
return []
specs: List[BwdPipelineSpec] = []
dropouts = BWD_SMALL_DROPOUTS if is_small else BWD_DROPOUTS
for mask, bias, dbias, dropout, deterministic in itertools.product(
MASKS,
BIASES,
BOOLS,
dropouts,
BOOLS,
):
if bias != "bias" and dbias == "t":
continue
for dpad, dvpad in BWD_EXTRA_PAD_COMBOS:
specs.append(
BwdPipelineSpec(
"bwd_dq_dk_dv",
mask,
bias,
dbias,
dropout,
deterministic,
dpad=dpad,
dvpad=dvpad,
)
)
return specs
def get_bwd_dot_do_o_pipelines(dtype: str) -> List[BwdPipelineSpec]:
"""BWD dot_do_o: spad x dvpad variants."""
if dtype not in DT_FP16_BF16:
return []
return [
BwdPipelineSpec("bwd_dot_do_o", spad=s, dvpad=d)
for s, d in itertools.product(BOOLS, BOOLS)
]
def get_bwd_convert_dq_pipelines(dtype: str, hdim: int = 0) -> List[BwdPipelineSpec]:
"""BWD convert_dq: spad x deterministic x dpad."""
if dtype not in DT_FP16_BF16:
return []
dpads = ["f", "t", "8"] if hdim == 128 else BOOLS
return [
BwdPipelineSpec("bwd_convert_dq", spad=s, deterministic=d, dpad=dp)
for s, d, dp in itertools.product(BOOLS, BOOLS, dpads)
]
# =============================================================================
# Tile compatibility (used by expand to double-check)
# =============================================================================
def tile_compatible(
arch: str,
dtype: str,
hdim: int,
hdim_v: int,
pipeline_tag: str,
tile: Tuple[int, ...],
) -> bool:
"""Check if a tile tuple passes arch-specific constraints (subset of tile_passes_all_constraints)."""
bm0, bn0, bk0 = tile[0], tile[1], tile[2]
if not check_gfx9_tile_constraints(
dtype, hdim, hdim_v, pipeline_tag, bm0, bn0, bk0
):
return False
if arch == "gfx950":
if not check_gfx950_tile_constraints(hdim, hdim_v, pipeline_tag, bm0, bn0):
return False
# Use default warp for mfma check
wn0, wk0 = 32, 16
warp_classes = WARP_CLASSES.get(dtype, [(32, 32, 16)])
if warp_classes:
_, wn0, wk0 = warp_classes[0]
if not check_qr_mfma_insts(arch, hdim, pipeline_tag, bn0, bk0, wn0, wk0):
return False
return True
# =============================================================================
# BWD wave/warp lookup
# =============================================================================
def bwd_dq_wave_warp(tile, hq, trload=False):
"""Look up BWD wave/warp config for a tile."""
trl = "t" if trload else "f"
key = (tile[0], tile[1], tile[2], trl)
entry = BWD_DQ_WAVE_WARP.get(key)
if entry is None:
for k, v in BWD_DQ_WAVE_WARP.items():
if k[:3] == (tile[0], tile[1], tile[2]):
entry = v
break
if entry is None:
bn0 = tile[1]
wn = min(4, max(1, bn0 // 32))
return {
"wave_m0": 1,
"wave_n0": wn,
"wave_k0": 1,
"wave_m1": 4,
"wave_n1": 1,
"wave_k1": 1,
"wave_m2": 1,
"wave_n2": wn,
"wave_k2": 1,
"warp_m0": 16,
"warp_n0": 16,
"warp_k0": 32,
"warp_m1": 16,
"warp_n1": 16,
"warp_k1": 16,
"warp_m2": 16,
"warp_n2": 16,
"warp_k2": 16,
}
w = entry["wave"]
wk1 = entry["warp_k1"]
return {
"wave_m0": w[0],
"wave_n0": w[1],
"wave_k0": w[2],
"wave_m1": w[3],
"wave_n1": w[4],
"wave_k1": w[5],
"wave_m2": w[6],
"wave_n2": w[7],
"wave_k2": w[8],
"warp_m0": 16,
"warp_n0": 16,
"warp_k0": 32,
"warp_m1": 16,
"warp_n1": 16,
"warp_k1": wk1,
"warp_m2": 16,
"warp_n2": 16,
"warp_k2": 16,
}
# =============================================================================
# Instance expansion
# =============================================================================
VARIANT_TO_FAMILY = {
"fwd": "fwd",
"bwd": "bwd_dq_dk_dv",
"splitkv": "fwd_splitkv",
"appendkv": "fwd_appendkv",
"pagedkv": "fwd_pagedkv",
"batch_prefill": "batch_prefill",
}
MODES = ["batch", "group"]
_MASK_MAP = {"no": "no", "causal": "top_left", "generic": "generic"}
_BIAS_MAP = {"no": "no", "bias": "bias", "alibi": "alibi"}
def _pad_val(s: str) -> int:
if s == "f":
return 0
if s == "t":
return 1
return int(s)
def expand_sweep(
config_path: Optional[str],
arch: str,
receipt: int = 0,
mode: str = "rules",
restrict_hdims: Optional[List[Tuple[int, int]]] = None,
default_variant: str = "fwd",
) -> List[FmhaKernelConfig]:
"""Expand sweep into full kernel instance list.
Args:
config_path: Path to JSON sweep config, or None for defaults
(only valid with mode="exhaustive").
arch: Target GPU arch ("gfx950" etc.).
receipt: Receipt level (0 = full, higher = filtered).
mode: "rules" applies tile_passes_all_constraints + receipt-driven
pipeline×feature coupling. "exhaustive" skips constraints and uses
a raw cartesian feature product (variant must be "fwd").
restrict_hdims: If set, only generate configs for these (hq, hv) pairs.
default_variant: Variant to use when config_path is None.
"""
if config_path is None:
if mode != "exhaustive":
raise ValueError("config_path is required for mode='rules'")
config = {"variant": default_variant, "trait_config": {}}
else:
with open(config_path) as f:
config = json.load(f)
variant = config.get("variant", default_variant)
# Build allow-list filters from JSON trait_config
trait_cfg = config.get("trait_config", {})
def _allow(key: str) -> Optional[Set[str]]:
entry = trait_cfg.get(key)
if entry is None:
return None
return set(entry.get("values", []))
allowed_dtypes = _allow("data_type")
allowed_pipes = _allow("pipeline")
allowed_masks = _allow("mask")
allowed_biases = _allow("bias")
allowed_modes = _allow("mode")
allowed_lse = _allow("lse")
allowed_dropout = _allow("dropout")
allowed_logits = _allow("logits")
allowed_sink = _allow("sink")
allowed_paged_kv = _allow("paged_kv")
# block_per_cu: int or list of ints to sweep
bpc_entry = trait_cfg.get("block_per_cu", {})
block_per_cu_values = bpc_entry.get("values", [-1])
if isinstance(block_per_cu_values, int):
block_per_cu_values = [block_per_cu_values]
# Intersect with arch support
arch_dtypes = set(ARCH_DTYPES.get(arch, ARCH_DTYPES.get("gfx950", [])))
dtypes = (
sorted(allowed_dtypes & arch_dtypes) if allowed_dtypes else sorted(arch_dtypes)
)
configs: List[FmhaKernelConfig] = []
if mode == "exhaustive":
if variant == "fwd":
configs = _expand_fwd_exhaustive(
arch,
dtypes,
allowed_pipes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_lse,
allowed_dropout,
allowed_logits,
allowed_sink,
block_per_cu_values,
restrict_hdims,
)
elif variant == "splitkv":
configs = _expand_splitkv_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_logits,
allowed_sink,
allowed_paged_kv,
restrict_hdims,
)
elif variant == "pagedkv":
configs = _expand_pagedkv_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims,
)
elif variant == "bwd":
configs = _expand_bwd_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims,
)
elif variant in ("appendkv", "batch_prefill"):
# These have fixed tiles (no tile sweep), so exhaustive = rules mode
if variant == "appendkv":
configs = _expand_appendkv(
arch, dtypes, 0, restrict_hdims=restrict_hdims
)
else:
configs = _expand_batch_prefill(
arch,
dtypes,
0,
allowed_masks,
allowed_biases,
restrict_hdims=restrict_hdims,
)
else:
raise ValueError(f"Exhaustive mode not supported for variant {variant!r}")
elif variant == "fwd":
configs = _expand_fwd(
arch,
dtypes,
receipt,
allowed_pipes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_lse,
allowed_dropout,
allowed_logits,
allowed_sink,
block_per_cu_values,
restrict_hdims=restrict_hdims,
)
elif variant == "splitkv":
configs = _expand_splitkv(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_logits,
allowed_sink,
allowed_paged_kv,
restrict_hdims=restrict_hdims,
)
elif variant == "pagedkv":
configs = _expand_pagedkv(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims=restrict_hdims,
)
elif variant == "appendkv":
configs = _expand_appendkv(arch, dtypes, receipt, restrict_hdims=restrict_hdims)
elif variant == "batch_prefill":
configs = _expand_batch_prefill(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
restrict_hdims=restrict_hdims,
)
elif variant == "bwd":
configs = _expand_bwd(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims=restrict_hdims,
)
# Dedup
seen: set = set()
unique: List[FmhaKernelConfig] = []
for c in configs:
if c.name not in seen:
seen.add(c.name)
unique.append(c)
return unique
def _build_fwd_kernel_config(
*,
arch: str,
dtype: str,
mode: str,
hq: int,
hv: int,
pipeline: str,
tc: FmhaTileConfig,
pad_s: int = 0,
pad_sk: int = 0,
pad_d: int = 0,
pad_dv: int = 0,
mask: str = "no",
bias: str = "no",
lse: bool = False,
dropout: bool = False,
logits: bool = False,
sink: bool = False,
skip_min_seqlen_q: bool = False,
qscale: str = "no",
block_per_cu: int = -1,
) -> FmhaKernelConfig:
"""Single source of truth for fwd FmhaKernelConfig kwargs derived from a tile."""
return FmhaKernelConfig(
family="fwd",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline=pipeline,
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
pad_s=pad_s,
pad_sk=pad_sk,
pad_d=pad_d,
pad_dv=pad_dv,
mask=mask,
bias=bias,
lse=lse,
dropout=dropout,
logits=logits,
sink=sink,
skip_min_seqlen_q=skip_min_seqlen_q,
qscale=qscale,
block_per_cu=block_per_cu,
gfx_arch=arch,
)
def _expand_fwd(
arch,
dtypes,
receipt,
allowed_pipes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_lse,
allowed_dropout,
allowed_logits,
allowed_sink,
block_per_cu_values=None,
restrict_hdims=None,
):
if block_per_cu_values is None:
block_per_cu_values = [-1]
configs = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
pipeline_specs = get_pipelines_for_config(arch, dtype, hq, hv, receipt)
_tile_cache: Dict[str, List[FmhaTileConfig]] = {}
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in pipeline_specs:
if not check_group_mode_padding(mode, spec.spad, spec.skpad):
continue
if allowed_pipes is not None and spec.tag not in allowed_pipes:
continue
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
lv = spec.lse == "t"
dv = spec.dropout == "t"
lgv = spec.logits == "t"
sv = spec.sink == "t"
skv = spec.skip == "t"
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
if allowed_lse is not None and lv not in allowed_lse:
continue
if allowed_dropout is not None and dv not in allowed_dropout:
continue
if allowed_logits is not None and lgv not in allowed_logits:
continue
if allowed_sink is not None and sv not in allowed_sink:
continue
if spec.tag not in _tile_cache:
_tile_cache[spec.tag] = generate_fwd_tiles(
arch, dtype, hq, hv, spec.tag
)
for tc in _tile_cache[spec.tag]:
t6 = (tc.bm0, tc.bn0, tc.bk0, tc.bn1, tc.bk1, tc.bk0max)
if not tile_compatible(arch, dtype, hq, hv, spec.tag, t6):
continue
for bpc in block_per_cu_values:
configs.append(
_build_fwd_kernel_config(
arch=arch,
dtype=dtype,
mode=mode,
hq=hq,
hv=hv,
pipeline=spec.tag,
tc=tc,
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
mask=mm,
bias=mb,
lse=lv,
dropout=dv,
logits=lgv,
sink=sv,
skip_min_seqlen_q=skv,
qscale=spec.qscale,
block_per_cu=bpc,
)
)
return configs
def _expand_fwd_exhaustive(
arch,
dtypes,
allowed_pipes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_lse,
allowed_dropout,
allowed_logits,
allowed_sink,
block_per_cu_values,
restrict_hdims,
):
"""Exhaustive fwd expansion: ALL tiles (no constraint filter) × full feature cross-product.
Differs from _expand_fwd in two ways:
1. Tiles come from generate_fwd_tiles(..., apply_constraints=False)
2. Features are a raw cartesian product (no pipeline-receipt coupling)
Used by --tiles=exhaustive in the benchmark to discover compilable tiles
that the rules engine rejects.
"""
pipelines = (
sorted(allowed_pipes)
if allowed_pipes
else ["qr", "qr_async", "qr_async_trload", "qr_async_trload_v3"]
)
modes = sorted(allowed_modes) if allowed_modes else MODES
masks = (
sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"]
)
biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"]
lse_vals = sorted(allowed_lse) if allowed_lse else [False, True]
dropout_vals = sorted(allowed_dropout) if allowed_dropout else [False, True]
logits_vals = sorted(allowed_logits) if allowed_logits else [False, True]
sink_vals = sorted(allowed_sink) if allowed_sink else [False]
bpc_vals = block_per_cu_values if block_per_cu_values else [-1, 1, 2]
configs: List[FmhaKernelConfig] = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
for pipeline in pipelines:
tiles = generate_fwd_tiles(
arch, dtype, hq, hv, pipeline, apply_constraints=False
)
for tc in tiles:
for mode, mask, bias, lv, dv, lgv, sv, bpc in itertools.product(
modes,
masks,
biases,
lse_vals,
dropout_vals,
logits_vals,
sink_vals,
bpc_vals,
):
configs.append(
_build_fwd_kernel_config(
arch=arch,
dtype=dtype,
mode=mode,
hq=hq,
hv=hv,
pipeline=pipeline,
tc=tc,
mask=mask,
bias=bias,
lse=lv,
dropout=dv,
logits=lgv,
sink=sv,
block_per_cu=bpc,
)
)
return configs
def _expand_splitkv_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_logits,
allowed_sink,
allowed_paged_kv,
restrict_hdims,
):
"""Exhaustive splitkv: ALL tiles (no LDS filter) × full feature product."""
modes = sorted(allowed_modes) if allowed_modes else MODES
masks = (
sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"]
)
biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"]
logits_vals = sorted(allowed_logits) if allowed_logits else [False, True]
sink_vals = sorted(allowed_sink) if allowed_sink else [False]
configs: List[FmhaKernelConfig] = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_splitkv_tiles(arch, dtype, hq, hv, apply_constraints=False)
for tc in tiles:
for mode, mask, bias, lgv, sv in itertools.product(
modes,
masks,
biases,
logits_vals,
sink_vals,
):
configs.append(
FmhaKernelConfig(
family="fwd_splitkv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline="qr",
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
mask=mask,
bias=bias,
lse=True,
logits=lgv,
sink=sv,
gfx_arch=arch,
)
)
return configs
def _expand_pagedkv_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims,
):
"""Exhaustive pagedkv: ALL tiles (no LDS filter) × full feature product."""
modes = sorted(allowed_modes) if allowed_modes else MODES
masks = (
sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"]
)
biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"]
configs: List[FmhaKernelConfig] = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_pagedkv_tiles(arch, dtype, hq, hv, apply_constraints=False)
for tc in tiles:
for mode, mask, bias in itertools.product(modes, masks, biases):
configs.append(
FmhaKernelConfig(
family="fwd_pagedkv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline="qr_pagedkv",
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
mask=mask,
bias=bias,
paged_kv=True,
gfx_arch=arch,
)
)
return configs
def _expand_bwd_exhaustive(
arch,
dtypes,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims,
):
"""Exhaustive bwd: ALL tiles (no LDS filter) × full feature product.
Note: BWD uses spec-defined fixed tiles for dq_dk_dv, but we can still
exhaust the dot_do_o and convert_dq with unfiltered tile generation.
For dq_dk_dv we use generate_bwd_tiles(apply_constraints=False) since
CK's bwd tile tables are hand-curated and the exhaustive sweep should
explore beyond them.
"""
modes = sorted(allowed_modes) if allowed_modes else MODES
masks = (
sorted(allowed_masks) if allowed_masks else ["no", "top_left", "bottom_right"]
)
biases = sorted(allowed_biases) if allowed_biases else ["no", "bias", "alibi"]
deterministic_vals = [False, True]
dropout_vals = ["no", "p", "rp"]
configs: List[FmhaKernelConfig] = []
for dtype in dtypes:
if dtype not in ("fp16", "bf16"):
continue
# dot_do_o — fixed tile, just sweep features
dot_specs = get_bwd_dot_do_o_pipelines(dtype)
for hd in BWD_DOT_DO_O_HDIMS:
if restrict_hdims is not None and (hd, hd) not in restrict_hdims:
continue
for mode in modes:
for spec in dot_specs:
configs.append(
FmhaKernelConfig(
family="bwd_dot_do_o",
data_type=dtype,
mode=mode,
hdim_q=hd,
hdim_v=hd,
pipeline="qr",
tile_m0=64,
pad_s=_pad_val(spec.spad),
pad_dv=_pad_val(spec.dvpad),
gfx_arch=arch,
)
)
# dq_dk_dv — exhaustive tiles
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_bwd_tiles(arch, dtype, hq, hv, apply_constraints=False)
for tc in tiles:
for mode, mask, bias, dropout, det in itertools.product(
modes,
masks,
biases,
dropout_vals,
deterministic_vals,
):
ww = bwd_dq_wave_warp((tc.bm0, tc.bn0, tc.bk0), hq)
configs.append(
FmhaKernelConfig(
family="bwd_dq_dk_dv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline="qr",
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
mask=mask,
bias=bias,
dropout=(dropout != "no"),
dropout_variant=dropout,
deterministic=det,
gfx_arch=arch,
**ww,
)
)
# convert_dq — no tile sweep (fixed tile), just feature sweep
for hd in BWD_CONVERT_DQ_HDIMS:
if restrict_hdims is not None and (hd, hd) not in restrict_hdims:
continue
for mode, det in itertools.product(modes, deterministic_vals):
configs.append(
FmhaKernelConfig(
family="bwd_convert_dq",
data_type=dtype,
mode=mode,
hdim_q=hd,
hdim_v=hd,
pipeline="qr",
tile_m0=64,
deterministic=det,
gfx_arch=arch,
)
)
return configs
def _expand_splitkv(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
allowed_logits=None,
allowed_sink=None,
allowed_paged_kv=None,
restrict_hdims=None,
):
configs = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_splitkv_tiles(arch, dtype, hq, hv)
sk_specs = get_splitkv_pipelines(dtype, hq, receipt)
for tc in tiles:
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in sk_specs:
if mode == "group" and not (
spec.spad == "t" and spec.skpad == "t"
):
continue
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
lgv = spec.logits == "t"
sv = spec.sink == "t"
pkv = spec.pagedkv == "t"
if allowed_logits is not None and lgv not in allowed_logits:
continue
if allowed_sink is not None and sv not in allowed_sink:
continue
if allowed_paged_kv is not None and pkv not in allowed_paged_kv:
continue
configs.append(
FmhaKernelConfig(
family="fwd_splitkv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline=spec.tag,
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
mask=mm,
bias=mb,
lse=True,
logits=lgv,
sink=sv,
paged_kv=pkv,
gfx_arch=arch,
)
)
# Combine kernels
for dtype in dtypes:
comb_specs = get_splitkv_combine_pipelines(dtype, receipt)
if not comb_specs:
continue
hdims = (
SPLITKV_COMBINE_HDIMS_FP16
if dtype in ("fp16", "bf16")
else SPLITKV_COMBINE_HDIMS_FP8
if dtype in ("fp8", "bf8")
else []
)
for hv in hdims:
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in comb_specs:
if mode == "group" and spec.spad != "t":
continue
configs.append(
FmhaKernelConfig(
family="fwd_splitkv_combine",
data_type=dtype,
mode=mode,
hdim_q=hv,
hdim_v=hv,
pipeline="splitkv_combine",
tile_m0=32,
tile_n0=hv,
tile_k0=32,
tile_n1=32,
pad_s=_pad_val(spec.spad),
pad_dv=_pad_val(spec.dvpad),
lse=(spec.lse == "t"),
gfx_arch=arch,
)
)
return configs
def _expand_pagedkv(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims=None,
):
configs = []
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_pagedkv_tiles(arch, dtype, hq, hv)
pk_specs = get_pagedkv_pipelines(dtype, hq, receipt)
for tc in tiles:
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in pk_specs:
if mode == "group" and not (
spec.spad == "t" and spec.skpad == "t"
):
continue
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
configs.append(
FmhaKernelConfig(
family="fwd_pagedkv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline=spec.tag,
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=tc.bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
mask=mm,
bias=mb,
logits=(spec.logits == "t"),
skip_min_seqlen_q=(spec.skip == "t"),
sink=(spec.sink == "t"),
paged_kv=True,
gfx_arch=arch,
)
)
return configs
def _expand_appendkv(arch, dtypes, receipt, restrict_hdims=None):
configs = []
for dtype in dtypes:
ak_specs = get_appendkv_pipelines(dtype, 0, receipt)
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
for spec in ak_specs:
configs.append(
FmhaKernelConfig(
family="fwd_appendkv",
data_type=dtype,
mode="batch",
hdim_q=hq,
hdim_v=hv,
pipeline="appendkv",
tile_m0=64,
tile_n0=64,
tile_k0=hq,
tile_n1=hv,
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
rope={
"none": "none",
"interleaved": "interleaved",
"half_rotated": "half_rotated",
}.get(spec.rope, spec.rope),
paged_kv=(spec.pagedkv == "t"),
gfx_arch=arch,
)
)
return configs
def _expand_batch_prefill(
arch, dtypes, receipt, allowed_masks, allowed_biases, restrict_hdims=None
):
configs = []
page_sizes = [1, 16, 1024]
def _bp_bk1(bm0, bn0, bk0, hq):
if bm0 == 64 and bn0 == 128 and bk0 == 64 and hq == 128:
return 64
return 32
for dtype in dtypes:
hdims = SUPPORTED_HDIMS.get(dtype, [])
if restrict_hdims is not None:
hdims = [hv for hv in hdims if hv in restrict_hdims]
for hq, hv in hdims:
tiles = generate_splitkv_tiles(arch, dtype, hq, hv)
bp_specs = get_batch_prefill_pipelines(dtype, hq, receipt)
for tc in tiles:
bk1 = _bp_bk1(tc.bm0, tc.bn0, tc.bk0, hq)
for spec in bp_specs:
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
for ps in page_sizes:
if ps == 1 and spec.kv_memory_layout != "linear":
continue
if spec.qscale == "kv_blockscale" and ps < tc.bn0:
continue
configs.append(
FmhaKernelConfig(
family="batch_prefill",
data_type=dtype,
mode="group",
hdim_q=hq,
hdim_v=hv,
pipeline="qr_async",
tile_m0=tc.bm0,
tile_n0=tc.bn0,
tile_k0=tc.bk0,
tile_n1=tc.bn1,
tile_k1=bk1,
tile_k0max=tc.bk0max,
wave_m0=tc.rm0,
wave_n0=1,
wave_k0=1,
wave_m1=tc.rm0,
wave_n1=1,
wave_k1=1,
warp_m0=tc.wm0,
warp_n0=tc.wn0,
warp_k0=tc.wk0,
warp_m1=tc.wm1,
warp_n1=tc.wn1,
warp_k1=tc.wk1,
pad_s=1,
pad_sk=1,
pad_d=1,
pad_dv=1,
mask=mm,
bias=mb,
lse=(spec.lse == "t"),
dropout=(spec.dropout == "t"),
logits=(spec.logits == "t"),
paged_kv=True,
page_size=ps,
kv_memory_layout=spec.kv_memory_layout,
kv_lookup_table=spec.kv_lookup_table,
qscale=spec.qscale,
gfx_arch=arch,
)
)
return configs
def _expand_bwd(
arch,
dtypes,
receipt,
allowed_masks,
allowed_biases,
allowed_modes,
restrict_hdims=None,
):
configs = []
for dtype in dtypes:
if dtype not in ("fp16", "bf16"):
continue
# dot_do_o
dot_specs = get_bwd_dot_do_o_pipelines(dtype)
for hd in BWD_DOT_DO_O_HDIMS:
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in dot_specs:
if mode == "group" and spec.spad != "t":
continue
configs.append(
FmhaKernelConfig(
family="bwd_dot_do_o",
data_type=dtype,
mode=mode,
hdim_q=hd,
hdim_v=hd,
pipeline="qr",
tile_m0=64,
pad_s=_pad_val(spec.spad),
pad_dv=_pad_val(spec.dvpad),
gfx_arch=arch,
)
)
# dq_dk_dv: main tiles
dq_specs = get_bwd_dq_dk_dv_pipelines(dtype, receipt)
for (hq, hv), tile in sorted(BWD_DQ_DK_DV_TILES_FP16.items()):
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in dq_specs:
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
ww = bwd_dq_wave_warp(tile, hq)
configs.append(
FmhaKernelConfig(
family="bwd_dq_dk_dv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline="qr",
tile_m0=tile[0],
tile_n0=tile[1],
tile_k0=tile[2],
tile_n1=tile[3] if len(tile) > 3 else hv,
tile_k1=tile[4] if len(tile) > 4 else tile[2],
tile_k0max=tile[5] if len(tile) > 5 else hq,
tile_bwd6=tile[6] if len(tile) > 6 else 0,
tile_bwd7=tile[7] if len(tile) > 7 else 0,
tile_bwd8=tile[8] if len(tile) > 8 else 0,
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
mask=mm,
bias=mb,
dbias=(spec.dbias == "t"),
dropout=(spec.dropout != "no"),
dropout_variant=spec.dropout,
deterministic=(spec.deterministic == "t"),
gfx_arch=arch,
**ww,
)
)
# dq_dk_dv: extra tiles
for (hq, hv), extra_entries in BWD_DQ_DK_DV_EXTRA_TILES.items():
for tile, tag, is_batch_only in extra_entries:
dq_extra_specs = get_bwd_dq_dk_dv_extra_pipelines(
dtype, is_small=is_batch_only, receipt=receipt
)
for mode in ["batch"] if is_batch_only else MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in dq_extra_specs:
mm = _MASK_MAP.get(spec.mask, spec.mask)
mb = _BIAS_MAP.get(spec.bias, spec.bias)
if allowed_masks is not None and mm not in allowed_masks:
continue
if allowed_biases is not None and mb not in allowed_biases:
continue
ww = bwd_dq_wave_warp(tile, hq, trload=(tag == "trload"))
configs.append(
FmhaKernelConfig(
family="bwd_dq_dk_dv",
data_type=dtype,
mode=mode,
hdim_q=hq,
hdim_v=hv,
pipeline="qr",
tile_m0=tile[0],
tile_n0=tile[1],
tile_k0=tile[2],
tile_n1=tile[3] if len(tile) > 3 else hv,
tile_k1=tile[4] if len(tile) > 4 else tile[2],
tile_k0max=tile[5] if len(tile) > 5 else hq,
tile_bwd6=tile[6] if len(tile) > 6 else 0,
tile_bwd7=tile[7] if len(tile) > 7 else 0,
tile_bwd8=tile[8] if len(tile) > 8 else 0,
tile_tag=tag,
use_trload=(tag == "trload"),
pad_s=_pad_val(spec.spad),
pad_sk=_pad_val(spec.skpad),
pad_d=_pad_val(spec.dpad),
pad_dv=_pad_val(spec.dvpad),
mask=mm,
bias=mb,
dbias=(spec.dbias == "t"),
dropout=(spec.dropout != "no"),
dropout_variant=spec.dropout,
deterministic=(spec.deterministic == "t"),
gfx_arch=arch,
**ww,
)
)
# convert_dq
for hd in BWD_CONVERT_DQ_HDIMS:
cvt_specs = get_bwd_convert_dq_pipelines(dtype, hd)
n_tile_groups = BWD_CONVERT_DQ_TILE_GROUPS.get(hd, 1)
for mode in MODES:
if allowed_modes is not None and mode not in allowed_modes:
continue
for spec in cvt_specs:
if mode == "group" and spec.spad != "t":
continue
for tile_grp in range(n_tile_groups):
configs.append(
FmhaKernelConfig(
family="bwd_convert_dq",
data_type=dtype,
mode=mode,
hdim_q=hd,
hdim_v=hd,
pipeline="qr",
tile_m0=64,
tile_tag=f"g{tile_grp}" if tile_grp > 0 else "",
pad_s=_pad_val(spec.spad),
pad_d=_pad_val(spec.dpad),
deterministic=(spec.deterministic == "t"),
gfx_arch=arch,
)
)
return configs
# =============================================================================
# Filter utility
# =============================================================================
def apply_filter(
configs: List[FmhaKernelConfig], expr: str = "", filter_file: str = ""
) -> List[FmhaKernelConfig]:
"""Apply user-defined filters to a config list."""
result = configs
if filter_file:
import importlib.util
spec = importlib.util.spec_from_file_location("user_filter", filter_file)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
fn = getattr(mod, "filter_config")
result = [c for c in result if fn(c)]
if expr:
result = [c for c in result if eval(expr, {"c": c})] # noqa: S307
return result
# =============================================================================
# CLI
# =============================================================================
def main():
parser = argparse.ArgumentParser(description="FMHA instance enumeration")
parser.add_argument("config", help="Sweep config JSON")
parser.add_argument("--arch", default="gfx950")
parser.add_argument("--receipt", type=int, default=0)
parser.add_argument(
"--filter",
dest="filter_expr",
default="",
help='Python expression per config, e.g. "c.hdim_q == 128"',
)
parser.add_argument(
"--filter-file",
default="",
help="Path to .py file with filter_config(c) -> bool",
)
parser.add_argument("--list", action="store_true")
parser.add_argument("--count-only", action="store_true")
args = parser.parse_args()
configs = expand_sweep(args.config, args.arch, args.receipt)
before = len(configs)
configs = apply_filter(configs, args.filter_expr, args.filter_file)
filtered = before - len(configs)
print(
f"Expanded {args.config} -> {before} configs"
f"{f' (filtered {filtered}, kept {len(configs)})' if filtered else ''}"
)
if args.count_only:
return
if args.list:
for i, c in enumerate(configs):
print(f" [{i}] {c.name}")
if __name__ == "__main__":
main()