[rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)

[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher (#5260)

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution
but has no support for Fused Multi-Head Attention (FMHA). The
example/ck_tile/01_fmha folder contains a comprehensive FMHA
implementation with forward, backward, split-KV, paged-KV, append-KV,
and batch-prefill kernels across multiple GPU architectures — but there
is no unified dispatch layer for it. This PR ports the FMHA stack into
the dispatcher, following the same architectural patterns established by
GEMM and Grouped Convolution, enabling runtime kernel selection, JIT
compilation from Python, and a declarative C++ example flow. Autotuning
heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring
GEMM's layered architecture. Seven new C++ runtime headers provide type
definitions (coexisting with upstream headers via __has_include,
requiring zero modifications to example/ck_tile/01_fmha/), a problem
builder with 18+ setters, Signature + Algorithm kernel key matching, a
virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard
support and named tile/wave/warp setters, arch-aware registry with JSON
export, and a dispatcher with seqtune-aware selection, configurable
timing, and multi-stage execution plans for split-KV (two-stage) and
backward (three-stage). The codegen pipeline is driven by a
fmha_arch_specs.json capturing per-arch tile tables and pipeline
constraints for five architectures (gfx90a/942/950/1100/1201), migrated
from hardcoded logic in 01_fmha/codegen/, with supporting modules for
C++ symbol mappings, validation rules, and named receipt profiles
(ck_default, flash, pytorch, aiter, fp32, fp8). Python integration
(fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel
multi-kernel builds, HIP memory management via ctypes, tolerance-based
validation, and a NumPy CPU reference with GQA support. Twenty-seven C++
and thirty-two Python examples cover the full feature surface — forward,
split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill,
fp8, logits soft cap, sink tokens, and parameter sweeps — all
JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end
correctness. C++ unit tests validate the problem builder, dispatcher
planning (single-stage for forward/paged-KV/append-KV; multi-stage for
split-KV and backward), registry operations, and the kernel-set
declaration macro. Python unit tests verify codegen emission, profile
filtering, and 15 validation rules for masks, hdim constraints, and
pipeline requirements. GPU execution validation in 01_basic_fmha
--validate reports zero errors across 65,536 elements with max absolute
error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py)
runs 14 configurations through both the upstream tile_example_fmha_fwd
and the dispatcher, comparing exit codes to confirm behavioral parity —
all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a
Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching
up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: Maksim (Max) Podkorytov <Maksim.Podkorytov@amd.com>
Co-authored-by: yashagar <yashagar@amd.com>
This commit is contained in:
Vidyasagar Ananthan
2026-05-17 00:29:40 -07:00
committed by GitHub
parent cc5c79a1e7
commit b20458e19e
148 changed files with 41250 additions and 87 deletions

View File

@@ -6,6 +6,7 @@ include_directories(BEFORE
${CMAKE_CURRENT_LIST_DIR}/ops
)
add_subdirectory(ops/fmha EXCLUDE_FROM_ALL)
add_subdirectory(ops/gemm EXCLUDE_FROM_ALL)
add_subdirectory(ops/gemm_streamk EXCLUDE_FROM_ALL)
add_subdirectory(ops/pooling EXCLUDE_FROM_ALL)

View File

@@ -16,7 +16,7 @@
| GEMM | grouped_gemm_quant | | ❌ | | ❌ | | | | ❌ | | | | ❌ | ❌ | ❌ | ❌ |
| Reduce | multi_reduce2d [8]<br>engine: reduce/<br>example: 05_reduce/ | ✅ | | ❌ | | | | | | | | | ❌ | ✅ | ✅ | ❌ |
| Reduce | reduce2d<br>example: 05_reduce/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Attention | fmha<br>example: 01_fmha/ | | | | ❌ | | | | | | | | | | | ❌ |
| Attention | fmha<br>engine: fmha/<br>example: 01_fmha/ | | | | ❌ | | | | | | | | | | | ❌ |
| Attention | sparse_attn<br>example: 50_sparse_attn/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Activation | topk_softmax<br>example: 09_topk_softmax/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generic multi-GPU parallel job runner for tile engine benchmarks.
Op-agnostic: takes opaque jobs, distributes them across GPUs with one
job per GPU at a time, and yields results in completion order. Used by
fmha_benchmark.py and reusable for gemm/reduce/pooling benchmarks.
"""
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Iterator, List, Optional, Tuple
def run_parallel_on_gpus(
jobs: List[Any],
gpu_ids: List[int],
run_one: Callable[[Any, int], Any],
max_workers: Optional[int] = None,
) -> Iterator[Tuple[int, Any]]:
"""Dispatch jobs across GPUs, one job per GPU at a time.
Args:
jobs: Opaque job objects passed to run_one.
gpu_ids: GPU IDs to use (e.g. [0,1,2,3]). At most one job per GPU runs concurrently.
run_one: Callable run_one(job, gpu_id) -> result. Caller is responsible
for any subprocess isolation, environment setup, etc.
max_workers: Thread pool size. Defaults to len(gpu_ids).
Yields:
(job_index, result) tuples in completion order. Caller can sort by
job_index to restore submission order if needed.
"""
if not jobs:
return
if max_workers is None:
max_workers = len(gpu_ids)
# One job per GPU at a time
gpu_semas = {gid: threading.Semaphore(1) for gid in gpu_ids}
cycle = [0]
cycle_lock = threading.Lock()
def _pick_gpu() -> int:
with cycle_lock:
gid = gpu_ids[cycle[0] % len(gpu_ids)]
cycle[0] += 1
return gid
def _wrapper(job_idx: int, job: Any) -> Tuple[int, Any]:
gid = _pick_gpu()
gpu_semas[gid].acquire()
try:
return job_idx, run_one(job, gid)
finally:
gpu_semas[gid].release()
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_wrapper, i, j) for i, j in enumerate(jobs)]
for fut in as_completed(futures):
yield fut.result()

3
tile_engine/ops/fmha/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
*.log
build/
*_build*/

View File

@@ -0,0 +1,94 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# FMHA Tile Engine -- Pure Python benchmarking via the CK dispatcher.
# No C++ per-kernel targets; all compilation is JIT via the dispatcher.
set(FMHA_TE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(FMHA_TE_CONFIGS ${FMHA_TE_DIR}/configs)
include(ProcessorCount)
ProcessorCount(NPROC)
if(NPROC EQUAL 0)
set(NPROC 8)
endif()
# Use first arch from SUPPORTED_GPU_TARGETS, or fallback to gfx950
set(FMHA_BENCH_ARCH "gfx950")
if(SUPPORTED_GPU_TARGETS)
list(GET SUPPORTED_GPU_TARGETS 0 FMHA_BENCH_ARCH)
endif()
# Main benchmark target (runs forward sweep by default)
add_custom_target(benchmark_fmha
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
${FMHA_TE_CONFIGS}/fwd.json
--arch ${FMHA_BENCH_ARCH}
--workers ${NPROC}
--best
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_fwd_results.json
WORKING_DIRECTORY ${FMHA_TE_DIR}
COMMENT "FMHA tile engine benchmark (forward)"
)
if(TARGET ck_tile_dispatcher)
add_dependencies(benchmark_fmha ck_tile_dispatcher)
endif()
# Per-variant convenience targets
foreach(variant fwd bwd splitkv appendkv pagedkv batch_prefill)
if(EXISTS ${FMHA_TE_CONFIGS}/${variant}.json)
add_custom_target(benchmark_fmha_${variant}
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
${FMHA_TE_CONFIGS}/${variant}.json
--arch ${FMHA_BENCH_ARCH}
--workers ${NPROC}
--best
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_${variant}_results.json
WORKING_DIRECTORY ${FMHA_TE_DIR}
COMMENT "FMHA tile engine benchmark (${variant})"
)
if(TARGET ck_tile_dispatcher)
add_dependencies(benchmark_fmha_${variant} ck_tile_dispatcher)
endif()
endif()
endforeach()
# CI target (minimal sweep for quick validation)
if(EXISTS ${FMHA_TE_CONFIGS}/fwd_ci.json)
add_custom_target(benchmark_fmha_ci
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
${FMHA_TE_CONFIGS}/fwd_ci.json
--arch ${FMHA_BENCH_ARCH}
--workers 8
--verify
WORKING_DIRECTORY ${FMHA_TE_DIR}
COMMENT "FMHA tile engine CI benchmark"
)
if(TARGET ck_tile_dispatcher)
add_dependencies(benchmark_fmha_ci ck_tile_dispatcher)
endif()
endif()
# All-variants target
set(FMHA_ALL_CONFIGS "")
foreach(cfg fwd bwd splitkv appendkv pagedkv batch_prefill)
if(EXISTS ${FMHA_TE_CONFIGS}/${cfg}.json)
list(APPEND FMHA_ALL_CONFIGS ${FMHA_TE_CONFIGS}/${cfg}.json)
endif()
endforeach()
add_custom_target(benchmark_fmha_all
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
${FMHA_ALL_CONFIGS}
--arch ${FMHA_BENCH_ARCH}
--workers ${NPROC}
--best
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_all_results.json
WORKING_DIRECTORY ${FMHA_TE_DIR}
COMMENT "FMHA tile engine benchmark (all variants)"
)
if(TARGET ck_tile_dispatcher)
add_dependencies(benchmark_fmha_all ck_tile_dispatcher)
endif()

View File

@@ -0,0 +1,192 @@
# FMHA Tile Engine
Benchmarking and kernel enumeration for Fused Multi-Head Attention (FMHA) via the CK dispatcher's pipelined JIT compilation.
Covers all 9 FMHA kernel families: Forward, Split-KV (main + combine), Paged-KV, Append-KV, Batch Prefill, and Backward (dot\_do\_o, dq\_dk\_dv, convert\_dq) -- totaling 33,541 unique kernel specializations on gfx950.
## Directory Layout
```
fmha/
fmha_instance_builder.py Kernel enumeration from JSON config + pipeline rules
fmha_benchmark.py Single-config JIT compile and GPU benchmark
fmha_full_benchmark.py Full sweep: compile all kernels, benchmark across test shapes
ck_fmha_testing_matrix.yaml Test shapes (smoke / full / nightly)
CMakeLists.txt CMake targets
README.md This file
configs/ Sweep definitions (JSON)
receipt0_fwd.json Full receipt-0 forward: ~12K kernels
fwd.json Forward variants
fwd_ci.json Minimal CI subset
bwd.json Backward variants
splitkv.json Split-KV
appendkv.json Append-KV
pagedkv.json Paged-KV
batch_prefill.json Batch prefill
filters/ Sample Python filter scripts
h128_no_dropout.py Keep only h128 without dropout
```
## Quick Start
```bash
# Count kernels without compiling
python fmha_instance_builder.py configs/receipt0_fwd.json --count-only
# Minimal CI build + run (~16 kernels, <1 min)
python fmha_benchmark.py configs/fwd_ci.json --workers 128 --verify
# Full forward receipt-0 compile-only (12K kernels, ~10 min with 256 workers)
python fmha_benchmark.py configs/receipt0_fwd.json --workers 256 --compile-only
# Full sweep: compile every fwd kernel, benchmark against all smoke shapes
python fmha_full_benchmark.py --category smoke --variant fwd --workers 256
# Quick end-to-end test (2 kernels, 1 shape)
python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 2 --workers 4
```
## How It Works
### Kernel Enumeration
```
JSON config (variant + trait_config allow-list)
--> fmha_instance_builder.py
--> fmha_pipeline_rules.py (self-contained CK parity logic)
--> fmha_arch_specs.json (tile tables per arch / dtype / hdim)
--> list of FmhaKernelConfig (33,541 total on gfx950)
--> optional --filter / --filter-file
```
The pipeline rules in `dispatcher/codegen/fmha_pipeline_rules.py` reproduce the exact kernel enumeration from CK Tile's `01_fmha/codegen/`, including per-arch tile constraints, pipeline selection, padding variants, and feature products. Parity is verified by `dispatcher/tests/validate_arch_specs_parity.py`.
### Benchmark Tools
**`fmha_benchmark.py`** -- single-config benchmark. Input: one JSON config (kernel definitions). JIT-compiles all matching kernels, runs each on a given problem size, reports per-kernel timing and optional CPU validation. Optionally writes `--csv` output.
**`fmha_full_benchmark.py`** -- full sweep benchmark. Input: `ck_fmha_testing_matrix.yaml` (test shapes) + JSON configs (kernel definitions). Compiles all kernel variants for selected families, then iterates over test shapes, matching each shape to compatible compiled kernels and benchmarking every match. Writes `--csv` and `--json` output.
### JIT Compilation Pipeline
Both tools use the dispatcher's `setup_multiple_fmha_dispatchers()` which implements a 3-stage pipelined build:
1. **Codegen** (parallel) -- generate C++ kernel specializations and ctypes wrappers
2. **Compile** (parallel) -- `hipcc` compile each kernel and ctypes lib
3. **Link + Load** (parallel) -- produce `.so` libraries, load via ctypes
With 256 workers, throughput is roughly 5-10 kernels/sec depending on kernel complexity.
## JSON Config Format
Each config specifies a `variant` and an optional `trait_config` that acts as an allow-list filter:
```json
{
"variant": "fwd",
"trait_config": {
"data_type": {"values": ["fp16", "bf16"]},
"pipeline": {"values": ["qr_async"]},
"mode": {"values": ["batch"]},
"mask": {"values": ["no"]},
"bias": {"values": ["no"]},
"lse": {"values": [false]},
"dropout": {"values": [false]},
"logits": {"values": [false]},
"sink": {"values": [false]}
}
}
```
If a trait key is absent, all values pass. The `receipt0_fwd.json` config only restricts `data_type` to exclude fp32, giving the full ~12K forward kernel set.
## Filtering
### CLI expression
```bash
python fmha_benchmark.py configs/receipt0_fwd.json \
--filter "c.hdim_q == 128 and c.pipeline == 'qr_async'"
python fmha_full_benchmark.py --variant fwd \
--filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'"
```
The expression accesses `c` (an `FmhaKernelConfig` dataclass) with fields: `data_type`, `mode`, `hdim_q`, `hdim_v`, `pipeline`, `tile_m0`, `tile_n0`, `tile_k0`, `pad_s`, `pad_sk`, `pad_d`, `pad_dv`, `mask`, `bias`, `lse`, `dropout`, `logits`, `sink`, `skip_min_seqlen_q`, `qscale`, `paged_kv`, `rope`, `deterministic`, `dbias`, `dropout_variant`.
### Python file filter
```bash
python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py
```
The file must define `filter_config(c) -> bool`. Both `--filter` and `--filter-file` combine with AND logic.
## Test Shape Matrix
`ck_fmha_testing_matrix.yaml` defines test problems in three tiers:
| Category | Purpose | Shapes |
|----------|---------|--------|
| `smoke` | Pre-submit sanity, <5 min | ~365 |
| `full` | Post-submit validation | smoke + ~1,500 |
| `nightly`| Exhaustive sweep | all |
Shapes cover representative configurations: GQA ratios, asymmetric head dims, non-power-of-2 sequences, FP8 variants, long sequences, and cross-attention patterns.
## Output Format
### CSV
```
problem_name,batch,seqlen_q,seqlen_k,nhead_q,nhead_k,hdim_q,hdim_v,dtype,
kernel,family,mode,pipeline,tile_m0,tile_n0,tile_k0,...,
latency_ms,tflops,bandwidth_gb_s
```
Every column needed to fully reconstruct the kernel identity is included. TFLOPS and latency come directly from CK's internal HIP event timing.
### JSON
```json
{
"metadata": {
"arch": "gfx950",
"category": "smoke",
"total_kernels": 600,
"shapes_benchmarked": 42,
"total_measurements": 12600
},
"results": [...]
}
```
## CMake Targets
```bash
make benchmark_fmha # Forward sweep
make benchmark_fmha_ci # Quick CI validation
make benchmark_fmha_bwd # Backward sweep
make benchmark_fmha_all # All variants
make benchmark_fmha_splitkv # Split-KV only
```
## Parity Verification
```bash
python dispatcher/tests/validate_arch_specs_parity.py --arch gfx950 --receipt 0
# PASS: 33,541 kernels across all 9 families
```
This confirms the dispatcher's self-contained enumeration exactly matches CK Tile's upstream codegen.
## Example: Single-Shape All-Kernel Benchmark
Run every compiled fwd fp16 h128 kernel against one shape:
```bash
python fmha_full_benchmark.py \
--category smoke --variant fwd --workers 256 \
--filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'" \
--csv results.csv
```

View File

@@ -0,0 +1,788 @@
test_categories:
Smoke:
description: "Pre-submit sanity checks. Fast execution, covering basic functionality and edge cases."
test_patterns:
- "*/Smoke.*"
labels: ["Smoke"]
Full:
description: "Post-submit validation. Comprehensive coverage of modern LLM architectures and CK operational constraints."
test_patterns:
- "*/Smoke.*"
- "*/Full.*"
labels: ["Full"]
Nightly:
description: "Nightly exhaustive coverage. Sweeps all combinations of precision, layout, masking, and padding."
test_patterns:
- "*"
labels: ["Nightly"]
execution_settings:
default_timeout: 60
category_timeouts:
Smoke: 60 # 1 min per test
Full: 300 # 5 min per test
Nightly: 600 # 10 min per test
# =============================================================================
# Forward Pass (Prefill) & Stochastic Execution (Dropout)
# =============================================================================
forward_tests:
# ---------------------------------------------------------------------------
# Smoke Tests (Fast, representative subset)
# ---------------------------------------------------------------------------
smoke:
- name: "GQA_4to1_Prefill_Basic"
description: "Baseline GQA prefill; primary optimization target."
batch: [1, 4]
seqlen_q: [2048]
seqlen_k: [2048]
nhead_q: [32]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false, true]
- name: "Small_GQA_7to1_SubWarp"
description: "Sub-warp vectorized loads; low LDS utilization bounds."
batch: [1]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [14]
nhead_k: [2]
hdim_q: [64]
hdim_v: [64]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "MHA_H96_Irregular_Dim"
description: "Non-power-of-2 hdim; forces complex padding/striding in LDS."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [32]
nhead_k: [32]
hdim_q: [96]
hdim_v: [96]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
# CK smoke test edge cases (from example/ck_tile/01_fmha/script/smoke_test_fwd.sh)
- name: "CK_Asymmetric_Hdim_Small"
description: "Asymmetric hdim_q != hdim_v; tests vectorized load widths."
batch: [2]
seqlen_q: [55]
seqlen_k: [256]
nhead_q: [2]
nhead_k: [1]
hdim_q: [16]
hdim_v: [32, 64, 128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "CK_Tiny_Sequences"
description: "Edge cases: sq=1, sq=3, very short sequences."
batch: [1, 2]
seqlen_q: [1, 3, 33]
seqlen_k: [10, 99, 33]
nhead_q: [2]
nhead_k: [1]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "CK_Asymmetric_Seqlen"
description: "Asymmetric seqlen_q != seqlen_k from CK smoke tests."
batch: [1, 2]
seqlen_q: [100, 99, 1024]
seqlen_k: [51, 256, 256]
nhead_q: [3]
nhead_k: [3]
hdim_q: [64, 128]
hdim_v: [64, 128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
# Hdim sweep covering all supported (hdim_q, hdim_v) pairs.
# YAML cartesian product creates some orphan combos (hdim_q != hdim_v pairs
# without kernels). The benchmark silently skips these. Use --validate to list them.
# Supported pairs: h32, h64, h80x96, h96, h96x128, h128, h160, h192x128, h192, h256
- name: "CK_All_Hdim_Sweep"
description: "Sweep all supported hdim combos. Orphan pairs are skipped at runtime."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [8]
nhead_k: [4]
hdim_q: [32, 64, 80, 96, 128, 160, 192, 256]
hdim_v: [32, 64, 96, 128, 160, 192, 256]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "CK_FP8_Basic"
description: "FP8 basic forward test."
batch: [1, 2]
seqlen_q: [128]
seqlen_k: [128]
nhead_q: [1]
nhead_k: [1]
hdim_q: [64, 128, 192, 256]
hdim_v: [64, 128, 128, 256]
dtype: ["fp8bf16", "fp8fp32"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
# Production model configs (from aiter model_shapes.json)
- name: "GQA_16to1_Large"
description: "16:1 GQA ratio (70B-class models)."
batch: [1, 4]
seqlen_q: [2048]
seqlen_k: [2048]
nhead_q: [64]
nhead_k: [4]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "MQA_128to8_Decode"
description: "405B-class decode: 128 Q heads, 8 KV heads, single token query."
batch: [1, 8, 64]
seqlen_q: [1]
seqlen_k: [1024, 4096]
nhead_q: [128]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "MLA_Sparse_Decode"
description: "Multi-latent attention decode (R1-class): asymmetric h192x128."
batch: [1, 4]
seqlen_q: [1]
seqlen_k: [1024, 4096]
nhead_q: [128]
nhead_k: [128]
hdim_q: [192]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Vision_Transformer_Shapes"
description: "Vision-text hybrid (Maverick-class): h88 and h128 mixed."
batch: [1, 4]
seqlen_q: [256, 1024]
seqlen_k: [256, 1024]
nhead_q: [16, 40]
nhead_k: [8, 16]
hdim_q: [88, 128]
hdim_v: [88, 128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "FP8_Varlen_Realistic"
description: "FP8 with realistic GQA and variable lengths (from aiter tests)."
batch: [1, 8]
seqlen_q: [113, 256, 1024]
seqlen_k: [203, 512, 1024]
nhead_q: [8, 32, 40]
nhead_k: [1, 8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp8bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Extreme_GQA_Ratios"
description: "Extreme GQA: 5:1, 10:1, 24:4, 48:8 from aiter test suite."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [5, 10, 24, 48]
nhead_k: [1, 1, 4, 8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Paged_Decode_Shapes"
description: "Paged attention decode patterns: single-token Q, long KV context."
batch: [4, 80, 128]
seqlen_q: [1, 4]
seqlen_k: [512, 4096]
nhead_q: [8, 16, 64]
nhead_k: [1, 4]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Prefill_Odd_Lengths"
description: "Prefill with non-standard seq lengths from aiter test suite."
batch: [2]
seqlen_q: [113, 339, 799, 1023, 3131]
seqlen_k: [203, 339, 799, 1024, 3131]
nhead_q: [32]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
# ---------------------------------------------------------------------------
# Full Tests (Modern LLM Architectures & CK Constraints)
# ---------------------------------------------------------------------------
full:
- name: "MHA_H256_High_LDS_Pressure"
description: "High LDS pressure; tests block partitioner limits with hdim=256."
batch: [1, 4]
seqlen_q: [4096]
seqlen_k: [4096]
nhead_q: [8]
nhead_k: [4]
hdim_q: [256]
hdim_v: [256]
dtype: ["bf16"]
layout: ["BHSD", "BSHD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [true]
- name: "MQA_64to1_Broadcast"
description: "Pure MQA; tests extreme KV to Q broadcast logic (64:1)."
batch: [2]
seqlen_q: [4096]
seqlen_k: [4096]
nhead_q: [64]
nhead_k: [1]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "GQA_6to1_Irregular"
description: "Irregular 6:1 GQA ratio; tests tile distribution."
batch: [2]
seqlen_q: [4096]
seqlen_k: [4096]
nhead_q: [48]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "MLA_H128xH576_Asymmetric"
description: "Multi-latent attention fusion; asymmetric Q/KV (128 vs 576)."
batch: [1, 4]
seqlen_q: [4096]
seqlen_k: [4096]
nhead_q: [128]
nhead_k: [128]
hdim_q: [128]
hdim_v: [576]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [true]
- name: "Asymmetric_Head_Dims_192_128"
description: "Test asymmetric head dimensions (192x128)."
batch: [2]
seqlen_q: [2048]
seqlen_k: [2048]
nhead_q: [16]
nhead_k: [16]
hdim_q: [192]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD", "BSHD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Asymmetric_Head_Dims_128_192"
description: "Test asymmetric head dimensions (128x192)."
batch: [2]
seqlen_q: [2048]
seqlen_k: [2048]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [192]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Diverse_Head_Dims_Sweep"
description: "Sweep across various head dimensions to ensure broad coverage."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [16]
nhead_k: [16]
hdim_q: [48, 64, 72, 96, 128, 160, 256]
hdim_v: [48, 64, 72, 96, 128, 160, 256]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Stochastic_Execution_Dropout_Sweep"
description: "PRNG state synchronization and warp alignment with stochastic masking across dims."
batch: [4]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [16]
nhead_k: [8]
hdim_q: [48, 64, 72, 96, 128, 160, 256]
hdim_v: [48, 64, 72, 96, 128, 160, 256]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.1, 0.2]
lse: [false, true]
- name: "Padding_Boundary_Stress_Odd_Lengths"
description: "Test sequences that are not perfect multiples of the tile size to validate padding logic."
batch: [2]
seqlen_q: [259, 500, 987, 1023]
seqlen_k: [259, 500, 987, 1023]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Bias_Variants_Sweep"
description: "Test elementwise and alibi bias across different sequence lengths and batch sizes."
batch: [1, 4]
seqlen_q: [512, 1024]
seqlen_k: [512, 1024]
nhead_q: [16]
nhead_k: [16]
hdim_q: [64, 128]
hdim_v: [64, 128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["elementwise", "alibi"]
dropout: [0.0]
lse: [false]
- name: "Extreme_Batch_Size_Stress"
description: "Test very large batch sizes to stress grid launch dimensions and scheduling."
batch: [64, 128, 256]
seqlen_q: [128]
seqlen_k: [128]
nhead_q: [8]
nhead_k: [8]
hdim_q: [64]
hdim_v: [64]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
lse: [false]
- name: "Long_Sequence_Stress"
description: "Test very long sequences (approaching split-KV territory but forced dense)."
batch: [1]
seqlen_q: [8192, 16384]
seqlen_k: [8192, 16384]
nhead_q: [16]
nhead_k: [4]
hdim_q: [128]
hdim_v: [128]
dtype: ["bf16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
lse: [true]
- name: "Cross_Attention_Shapes"
description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k."
batch: [2]
seqlen_q: [1, 32, 128]
seqlen_k: [1024, 4096]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
- name: "CK_Benchmark_Standard"
description: "Standard CK benchmark sweep (from benchmark_fwd.sh)."
batch: [32, 16, 8, 4, 2, 1]
seqlen_q: [512, 1024, 2048, 4096, 8192, 16384]
seqlen_k: [512, 1024, 2048, 4096, 8192, 16384]
nhead_q: [32, 16, 8]
nhead_k: [32, 16, 8]
hdim_q: [64, 128, 256]
hdim_v: [64, 128, 256]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
- name: "CK_Benchmark_V3_Large"
description: "V3 pipeline benchmark with very long sequences (from benchmark_fwd_v3.sh)."
batch: [1]
seqlen_q: [16384, 37200, 65536]
seqlen_k: [16384, 37200, 65536]
nhead_q: [16, 40, 64]
nhead_k: [1, 16, 40, 64]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
lse: [false]
# =============================================================================
# Backward Pass (Gradient Computation)
# =============================================================================
backward_tests:
# ---------------------------------------------------------------------------
# Smoke Tests
# ---------------------------------------------------------------------------
smoke:
- name: "Bwd_Basic_No_Features"
description: "Basic backward pass without optional features."
batch: [1, 2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_GQA_Smoke"
description: "Backward GQA smoke test (4:1 and 8:1 ratios)."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [32]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Hdim_Sweep_Smoke"
description: "Backward across key head dimensions."
batch: [2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [8]
nhead_k: [8]
hdim_q: [64, 96, 128, 256]
hdim_v: [64, 96, 128, 256]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_With_Mask_Dropout"
description: "Backward with causal mask and dropout."
batch: [2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [16]
nhead_k: [16]
hdim_q: [64, 128]
hdim_v: [64, 128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.1]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Asymmetric_Hdim_Smoke"
description: "Backward with asymmetric head dimensions."
batch: [2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [16]
nhead_k: [16]
hdim_q: [192]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
# ---------------------------------------------------------------------------
# Full Tests
# ---------------------------------------------------------------------------
full:
- name: "Bwd_GQA_Support"
description: "Backward pass with Grouped Query Attention."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [32, 64]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_High_Capacity_H256"
description: "Backward pass with hdim=256; high LDS pressure."
batch: [1]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [8]
nhead_k: [4]
hdim_q: [256]
hdim_v: [256]
dtype: ["bf16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Irregular_H96"
description: "Backward pass with non-power-of-2 hdim."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [32]
nhead_k: [32]
hdim_q: [96]
hdim_v: [96]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_All_Features_Enabled"
description: "Backward pass with bias gradients, dropout, and deterministic accumulation."
batch: [2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [16]
nhead_k: [16]
hdim_q: [48, 64, 72, 96, 128, 160, 256]
hdim_v: [48, 64, 72, 96, 128, 160, 256]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["elementwise", "alibi"]
dropout: [0.1]
has_dbias: [true]
is_deterministic: [true]
- name: "Bwd_Padding_Boundary_Stress"
description: "Test backward pass with sequences that are not perfect multiples of the tile size."
batch: [1]
seqlen_q: [259, 500, 1023]
seqlen_k: [259, 500, 1023]
nhead_q: [8]
nhead_k: [8]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask", "top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Asymmetric_Head_Dims_192_128"
description: "Test backward pass with asymmetric head dimensions (192x128)."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [16]
nhead_k: [16]
hdim_q: [192]
hdim_v: [128]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["top_left"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Asymmetric_Head_Dims_128_192"
description: "Test backward pass with asymmetric head dimensions (128x192)."
batch: [2]
seqlen_q: [1024]
seqlen_k: [1024]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [192]
dtype: ["fp16", "bf16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Diverse_Head_Dims_Sweep"
description: "Sweep backward pass across various head dimensions."
batch: [2]
seqlen_q: [512]
seqlen_k: [512]
nhead_q: [16]
nhead_k: [16]
hdim_q: [48, 64, 72, 96, 128, 160, 256]
hdim_v: [48, 64, 72, 96, 128, 160, 256]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]
- name: "Bwd_Cross_Attention_Shapes"
description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k in backward."
batch: [2]
seqlen_q: [1, 32, 128]
seqlen_k: [1024, 4096]
nhead_q: [16]
nhead_k: [16]
hdim_q: [128]
hdim_v: [128]
dtype: ["fp16"]
layout: ["BHSD"]
mask: ["no_mask"]
bias: ["none"]
dropout: [0.0]
has_dbias: [false]
is_deterministic: [false]

View File

@@ -0,0 +1,6 @@
{
"variant": "appendkv",
"trait_config": {
"data_type": {"values": ["fp16", "bf16", "fp8"]}
}
}

View File

@@ -0,0 +1,6 @@
{
"variant": "batch_prefill",
"trait_config": {
"data_type": {"values": ["fp16", "bf16", "fp8bf16"]}
}
}

View File

@@ -0,0 +1,6 @@
{
"variant": "bwd",
"trait_config": {
"data_type": {"values": ["fp16", "bf16"]}
}
}

View File

@@ -0,0 +1,9 @@
{
"variant": "fwd",
"trait_config": {
"data_type": {"values": ["fp16", "bf16"]},
"pipeline": {"values": ["qr", "qr_async"]},
"mask": {"values": ["no", "top_left"]},
"bias": {"values": ["no"]}
}
}

View File

@@ -0,0 +1,14 @@
{
"variant": "fwd",
"trait_config": {
"data_type": {"values": ["fp16"]},
"pipeline": {"values": ["qr_async"]},
"mask": {"values": ["no"]},
"bias": {"values": ["no"]},
"mode": {"values": ["batch"]},
"lse": {"values": [false]},
"dropout": {"values": [false]},
"logits": {"values": [false]},
"sink": {"values": [false]}
}
}

View File

@@ -0,0 +1,6 @@
{
"variant": "pagedkv",
"trait_config": {
"data_type": {"values": ["fp16", "bf16", "fp8"]}
}
}

View File

@@ -0,0 +1,6 @@
{
"variant": "fwd",
"trait_config": {
"data_type": {"values": ["fp16", "bf16", "fp8bf16", "fp8fp32"]}
}
}

View File

@@ -0,0 +1,6 @@
{
"variant": "splitkv",
"trait_config": {
"data_type": {"values": ["fp16", "bf16", "fp8"]}
}
}

View File

@@ -0,0 +1,14 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Sample filter: only h128 kernels without dropout.
Usage:
python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py
python fmha_instance_builder.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py --count-only
"""
def filter_config(c) -> bool:
"""Keep only h128 kernels without dropout."""
return c.hdim_q == 128 and not c.dropout

View File

@@ -0,0 +1,939 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
FMHA tile engine benchmark runner.
Uses the dispatcher's setup_multiple_fmha_dispatchers() for pipelined JIT
compilation, then runs GPU benchmarks and reports results.
Usage:
python fmha_benchmark.py configs/fwd.json
python fmha_benchmark.py configs/fwd.json --workers 256 --build-dir /tmp/fmha_build
python fmha_benchmark.py configs/fwd.json --problems "2,8,1024,128" --verify
"""
import argparse
import csv
import json
import os
import shutil
import sys
import time
from pathlib import Path
from typing import List
import numpy as np
_DISPATCHER_ROOT = Path(__file__).resolve().parents[3] / "dispatcher"
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen"))
from fmha_utils import ( # noqa: E402
FmhaProblem,
FmhaRunner,
cpu_attention_fwd,
detect_gpu_arch,
setup_multiple_fmha_dispatchers,
)
from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402
# Reusable multi-GPU job dispatcher (op-agnostic)
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "common"))
from parallel_runner import run_parallel_on_gpus # noqa: E402
def _compute_result(
config,
prob,
time_ms,
output,
ref,
is_causal,
ns,
api_family,
dtype_tol,
gpu_id=None,
):
"""Compute tflops, max_err, status and build result dict + display line.
Returns (result_dict, display_line) or None if time_ms is None/0.
"""
tflops = prob.num_ops / (time_ms * 1e-3) / 1e12 if time_ms > 0 else 0
if is_causal and time_ms > 0:
sq, sk = prob.seqlen_q, prob.seqlen_k
causal_ratio = (min(sq, sk) + 1) / (2.0 * sk)
tflops = prob.num_ops * causal_ratio / (time_ms * 1e-3) / 1e12
max_err = 0.0
status = "OK"
if ref is not None and output is not None:
max_err = float(np.abs(output.astype(np.float32) - ref).max())
atol, rtol = dtype_tol
tol = atol + rtol * np.abs(ref).max()
status = "PASS" if max_err < tol else "FAIL"
splits_tag = f" [ns={ns}]" if api_family == "splitkv" else ""
display_name = f"{config.name}{splits_tag}"
gpu_tag = f" [GPU{gpu_id}]" if gpu_id is not None else ""
display_line = (
f" {display_name:<105} {time_ms:>10.3f}"
f" {tflops:>10.2f} {max_err:>10.2e} {status:>6}{gpu_tag}"
)
result_dict = {
"kernel": config.name,
"dtype": config.data_type,
"hdim_q": config.hdim_q,
"hdim_v": config.hdim_v,
"pipeline": config.pipeline,
"mode": config.mode,
"mask": config.mask,
"bias": config.bias,
"tile_m0": config.tile_m0,
"tile_n0": config.tile_n0,
"tile_k0": config.tile_k0,
"tile_n1": config.tile_n1,
"tile_k1": config.tile_k1,
"tile_k0max": config.tile_k0max,
"warp_m0": config.warp_m0,
"warp_n0": config.warp_n0,
"warp_k0": config.warp_k0,
"block_per_cu": config.block_per_cu,
"num_splits": ns if api_family == "splitkv" else None,
"problem": {
"batch": prob.batch,
"nhead_q": prob.nhead_q,
"nhead_k": prob.nhead_k,
"seqlen_q": prob.seqlen_q,
"seqlen_k": prob.seqlen_k,
"hdim_q": prob.hdim_q,
"hdim_v": prob.hdim_v,
},
"latency_ms": time_ms,
"tflops": tflops,
"max_err": max_err,
"status": status,
}
return result_dict, display_line
def _run_kernel_isolated(
lib_path, arch, prob, run_kwargs, data_dir, gpu_id=0, timeout=120
):
"""Run a single kernel in a subprocess. Returns (time_ms, output_path) or (None, error_msg).
Survives GPU faults — if the subprocess crashes, returns an error instead of killing main.
"""
import json as _json
import subprocess as sp
# Write a small runner script that the subprocess will execute.
# Use json.dumps for string values to safely escape quotes/backslashes in paths.
_lib = _json.dumps(str(lib_path))
_arch = _json.dumps(str(arch))
_pydir = _json.dumps(str(_DISPATCHER_ROOT / "python"))
_ddir = _json.dumps(str(data_dir))
script = f'''
import sys, os, numpy as np
os.environ["HIP_VISIBLE_DEVICES"] = "{gpu_id}"
sys.path.insert(0, {_pydir})
from fmha_utils import FmhaRunner, FmhaProblem
runner = FmhaRunner.from_library({_lib}, {_arch})
_d = {_ddir}
Q = np.load(os.path.join(_d, "Q.npy"))
K = np.load(os.path.join(_d, "K.npy"))
V = np.load(os.path.join(_d, "V.npy"))
prob = FmhaProblem(batch={prob.batch}, nhead_q={prob.nhead_q}, nhead_k={prob.nhead_k},
seqlen_q={prob.seqlen_q}, seqlen_k={prob.seqlen_k},
hdim_q={prob.hdim_q}, hdim_v={prob.hdim_v})
result = runner.run(Q, K, V, prob, **{run_kwargs!r})
if result.success:
np.save(os.path.join(_d, "O.npy"), result.output)
print(f"TIME={{result.time_ms}}")
else:
print("FAIL")
runner.cleanup()
'''
script_path = os.path.join(data_dir, "run_kernel.py")
with open(script_path, "w") as f:
f.write(script)
try:
r = sp.run(
[sys.executable, script_path],
capture_output=True,
text=True,
timeout=timeout,
env={**os.environ, "HIP_VISIBLE_DEVICES": str(gpu_id)},
)
if r.returncode != 0:
err = r.stderr[-200:] if r.stderr else f"exit code {r.returncode}"
return None, None, f"CRASH: {err.strip()}"
# Parse time from stdout
for line in r.stdout.strip().split("\n"):
if line.startswith("TIME="):
time_ms = float(line[5:])
out_path = os.path.join(data_dir, "O.npy")
output = np.load(out_path) if os.path.exists(out_path) else None
return time_ms, output, None
return None, None, "No TIME output"
except sp.TimeoutExpired:
return None, None, "TIMEOUT"
except Exception as e:
return None, None, str(e)
def parse_problems(spec: str) -> List[FmhaProblem]:
"""Parse problem specs: 'batch,nhead,seqlen,hdim;...'"""
problems = []
for part in spec.split(";"):
vals = [int(x) for x in part.split(",")]
if len(vals) == 4:
b, h, s, d = vals
problems.append(
FmhaProblem(
batch=b,
nhead_q=h,
nhead_k=h,
seqlen_q=s,
seqlen_k=s,
hdim_q=d,
hdim_v=d,
)
)
elif len(vals) == 6:
b, hq, hk, sq, sk, d = vals
problems.append(
FmhaProblem(
batch=b,
nhead_q=hq,
nhead_k=hk,
seqlen_q=sq,
seqlen_k=sk,
hdim_q=d,
hdim_v=d,
)
)
return problems
def main():
parser = argparse.ArgumentParser(description="FMHA Tile Engine Benchmark")
parser.add_argument(
"configs", nargs="*", help="Sweep config JSON(s) (optional for exhaustive)"
)
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument(
"--workers", type=int, default=os.cpu_count() or 8, help="Parallel JIT workers"
)
parser.add_argument(
"--problems",
default="2,8,1024,128",
help="Problem sizes: batch,nhead,seqlen,hdim",
)
parser.add_argument(
"--no-verify", action="store_true", help="Skip CPU reference verification"
)
parser.add_argument(
"--best", action="store_true", help="Show best kernel per problem"
)
parser.add_argument(
"--csv",
type=str,
default=None,
help="CSV output path (default: <build-dir>/results.csv). Use --no-csv to disable.",
)
parser.add_argument("--no-csv", action="store_true", help="Disable CSV output")
parser.add_argument("--json", type=str, default=None)
parser.add_argument(
"--log",
type=str,
default=None,
help="Path to detailed log file (compilation status, failures, timings)",
)
parser.add_argument(
"--build-dir",
type=str,
default=str(Path(__file__).resolve().parent / "build"),
help="JIT build output directory",
)
parser.add_argument("--clean", action="store_true")
parser.add_argument("--compile-only", action="store_true")
parser.add_argument(
"--filter",
dest="filter_expr",
default="",
help='Python expr per config, e.g. "c.hdim_q == 128"',
)
parser.add_argument(
"--filter-file", default="", help="Path to .py with filter_config(c) -> bool"
)
parser.add_argument(
"--tiles",
choices=["rules", "exhaustive"],
default="rules",
help="Tile enumeration mode: 'rules' (default) uses constraint-based generation; "
"'exhaustive' brute-forces ALL compilable tiles (like the oracle)",
)
parser.add_argument(
"--num-splits",
default="1,2,4,8",
help="Comma-separated num_splits values to sweep for splitkv (default: 1,2,4,8)",
)
parser.add_argument(
"--isolate",
action="store_true",
help="Run each kernel in a subprocess to survive GPU faults (slower but fault-tolerant)",
)
parser.add_argument(
"--gpus",
type=str,
default=None,
help="Comma-separated GPU IDs to use for parallel benchmarking (e.g. '0,1,2,3'). "
"Implies --isolate. Each GPU runs one kernel at a time.",
)
args = parser.parse_args()
# --gpus implies --isolate
if args.gpus:
args.isolate = True
gpu_ids = [int(x) for x in args.gpus.split(",")] if args.gpus else [0]
problems = parse_problems(args.problems)
num_splits_list = [int(x) for x in args.num_splits.split(",")]
build_dir = Path(args.build_dir).resolve()
if args.clean and build_dir.exists():
print(f" Cleaning {build_dir} ...")
shutil.rmtree(build_dir)
build_dir.mkdir(parents=True, exist_ok=True)
# Phase 0: Expand configs
all_configs = []
restrict_hdims = sorted({(p.hdim_q, p.hdim_v) for p in problems})
if args.tiles == "exhaustive":
# Exhaustive mode: all tiles (no constraint filter) × full feature cross-product.
# JSON config is optional — if provided, its trait_config scopes the sweep.
cfg_path = args.configs[0] if args.configs else None
all_configs = expand_sweep(
cfg_path,
args.arch,
0,
mode="exhaustive",
restrict_hdims=restrict_hdims,
)
print(
f" Exhaustive: {len(all_configs)} total combos (all tiles × all features)"
)
else:
if not args.configs:
parser.error(
"Config JSON(s) required for rules mode. Use --tiles exhaustive to run without."
)
for cfg_path in args.configs:
configs = expand_sweep(
cfg_path,
args.arch,
0,
mode="rules",
restrict_hdims=restrict_hdims,
)
all_configs.extend(configs)
print(f" {cfg_path}: {len(configs)} kernel configs")
if args.filter_expr or args.filter_file:
before = len(all_configs)
all_configs = apply_filter(all_configs, args.filter_expr, args.filter_file)
print(f" Filter: {before} -> {len(all_configs)} configs")
# Remove standalone combine configs -- they are auto-paired during JIT
all_configs = [c for c in all_configs if c.family != "fwd_splitkv_combine"]
print(f"\n{'=' * 70}")
print("FMHA Tile Engine Benchmark")
print(f"{'=' * 70}")
print(f" Arch: {args.arch}")
print(f" Kernels: {len(all_configs)}")
print(f" Problems: {len(problems)}")
print(f" Workers: {args.workers}")
print(f" Build: {build_dir}")
# Phase 1: Pipelined JIT via the dispatcher
print(
f"\n--- Phase 1: JIT compile ({len(all_configs)} kernels,"
f" {args.workers} workers) ---"
)
jit_t0 = time.perf_counter()
def _progress(stage, done, total):
elapsed = time.perf_counter() - jit_t0
pct = done * 100 // total
print(
f"\r [{stage}] {done}/{total} ({pct}%) - {elapsed:.0f}s",
end="",
flush=True,
)
if done == total:
print()
setups = setup_multiple_fmha_dispatchers(
all_configs,
output_dir=build_dir,
verbose=True,
max_workers=args.workers,
progress_callback=_progress,
)
jit_time = time.perf_counter() - jit_t0
built = sum(1 for s in setups if s.success)
failed = len(all_configs) - built
print(f"\n Built {built}/{len(all_configs)} in {jit_time:.0f}s ({failed} failed)")
# Load runners for successfully compiled kernels
for setup in setups:
if setup.success and setup.library_path and setup.runner is None:
try:
setup.runner = FmhaRunner.from_library(setup.library_path, args.arch)
except Exception as e:
print(f" Warning: Failed to load runner: {e}")
setup.success = False
if args.compile_only:
print(f"\n{'=' * 70}")
print(f" Compile-only mode. {built}/{len(all_configs)} kernels compiled.")
if failed > 0:
print("\n Failed kernels:")
for cfg, s in zip(all_configs, setups):
if not s.success:
err = (s.error or "unknown")[:80]
print(f" {cfg.name}: {err}")
if args.tiles == "exhaustive":
# Oracle-style analysis: find tiles missed by rules vs compilable
from fmha.instance_gen import validate_tile, FmhaTileConfig # noqa: E402
missed = []
for cfg, s in zip(all_configs, setups):
if s.success:
tile = FmhaTileConfig(
bm0=cfg.tile_m0,
bn0=cfg.tile_n0,
bk0=cfg.tile_k0,
bn1=cfg.tile_n1,
bk1=cfg.tile_k1,
bk0max=cfg.tile_k0max,
rm0=cfg.wave_m0,
rn0=1,
rk0=1,
rm1=cfg.wave_m1,
rn1=1,
rk1=1,
wm0=cfg.warp_m0,
wn0=cfg.warp_n0,
wk0=cfg.warp_k0,
wm1=cfg.warp_m1,
wn1=cfg.warp_n1,
wk1=cfg.warp_k1,
)
if not validate_tile(
tile,
args.arch,
cfg.data_type,
cfg.hdim_q,
cfg.hdim_v,
cfg.pipeline,
):
missed.append(cfg)
if missed:
print(
f"\n MISSED by rules ({len(missed)} tiles compile but rules reject):"
)
seen = set()
for cfg in missed:
key = (cfg.tile_m0, cfg.tile_n0, cfg.tile_k0)
if key not in seen:
seen.add(key)
print(
f" ({cfg.tile_m0:>3}, {cfg.tile_n0:>3}, {cfg.tile_k0:>3})"
)
else:
print(
"\n Rules are COMPLETE — all compilable tiles are generated by rules."
)
print(f"{'=' * 70}")
return
# Phase 2: Benchmark
print(f"\n--- Phase 2: Benchmark ({built} kernels x {len(problems)} problems) ---")
dtype_map = {
"fp16": np.float16,
"bf16": np.float32,
"fp32": np.float32,
"fp8": np.float16,
"fp8bf16": np.float16,
"fp8fp32": np.float16,
"bf8": np.float16,
"mxfp8": np.float16,
"mxfp4": np.float16,
}
# Tolerance per dtype: (atol, rtol)
_DTYPE_TOL = {
"fp16": (1e-3, 1e-3),
"bf16": (1e-2, 1e-2),
"fp32": (1e-5, 1e-5),
"fp8": (16.0, 0.0),
"fp8bf16": (16.0, 0.0),
"fp8fp32": (16.0, 0.0),
"bf8": (16.0, 0.0),
"mxfp8": (16.0, 0.0),
"mxfp4": (32.0, 0.0),
}
np.random.seed(42)
all_results = []
bench_t0 = time.perf_counter()
for prob_idx, prob in enumerate(problems):
first_dtype = all_configs[0].data_type if all_configs else "fp16"
first_mask = all_configs[0].mask if all_configs else "no"
np_dtype = dtype_map.get(first_dtype, np.float16)
dtype_tol = _DTYPE_TOL.get(first_dtype, (1e-2, 1e-2))
# Use uniform [0, 1] like CK example (default 'uf' mode) -- produces
# peaked softmax distributions that actually test kernel correctness.
# randn*0.1 makes softmax nearly uniform for large hdim, hiding bugs.
Q = np.random.uniform(0, 1, prob.q_shape()).astype(np_dtype)
K = np.random.uniform(0, 1, prob.k_shape()).astype(np_dtype)
V = np.random.uniform(0, 1, prob.v_shape()).astype(np_dtype)
_MASK_INT = {"no": 0, "top_left": 1, "bottom_right": 2, "generic": 3}
first_mask_int = _MASK_INT.get(first_mask, 0)
ref = None
if not args.no_verify:
# For bf16: truncate inputs to bf16 precision before computing reference,
# so reference sees the SAME data the kernel sees (after bf16 encoding).
if first_dtype == "bf16":
from fmha_utils import _float32_to_bf16, _bf16_to_float32
Q_ref = _bf16_to_float32(_float32_to_bf16(Q.astype(np.float32)))
K_ref = _bf16_to_float32(_float32_to_bf16(K.astype(np.float32)))
V_ref = _bf16_to_float32(_float32_to_bf16(V.astype(np.float32)))
else:
Q_ref = Q.astype(np.float32)
K_ref = K.astype(np.float32)
V_ref = V.astype(np.float32)
ref = cpu_attention_fwd(
Q_ref,
K_ref,
V_ref,
prob.scale,
mask_type=first_mask_int,
)
h_str = (
f"H={prob.nhead_q}"
if prob.nhead_q == prob.nhead_k
else f"Hq={prob.nhead_q} Hk={prob.nhead_k}"
)
s_str = (
f"S={prob.seqlen_q}"
if prob.seqlen_q == prob.seqlen_k
else f"Sq={prob.seqlen_q} Sk={prob.seqlen_k}"
)
prob_str = f"B={prob.batch} {h_str} {s_str} D={prob.hdim_q}"
print(f"\n Problem [{prob_idx}]: {prob_str}")
print(
f" {'Kernel':<105} {'Time(ms)':>10} {'TFLOPS':>10}"
f" {'MaxErr':>10} {'Status':>6}"
)
print(f" {'-' * 145}")
_BIAS_INT = {"no": 0, "bias": 1, "alibi": 2}
# Build list of (config, setup, run_kwargs, ns) jobs for benchmarking
bench_jobs = []
for config, setup in zip(all_configs, setups):
if not setup.success:
continue
if not args.isolate and setup.runner is None:
continue
if config.hdim_q != prob.hdim_q or config.hdim_v != prob.hdim_v:
continue
mask_int = _MASK_INT.get(config.mask, 0)
is_causal = config.mask in ("top_left", "bottom_right")
is_group = config.mode == "group"
_FAMILY_TO_API = {
"fwd_splitkv": "splitkv",
"fwd_pagedkv": "pagedkv",
"fwd_appendkv": "appendkv",
}
api_family = _FAMILY_TO_API.get(config.family, config.family)
splits_to_try = num_splits_list if api_family == "splitkv" else [0]
for ns in splits_to_try:
run_kwargs = dict(
mask_type=mask_int,
bias_type=_BIAS_INT.get(config.bias, 0),
has_lse=int(config.lse),
has_dropout=int(config.dropout),
has_logits=int(config.logits),
has_sink=int(config.sink),
data_type=config.data_type,
is_group_mode=int(is_group),
is_v_rowmajor=int(config.vlayout == "r"),
api_family=api_family,
window_left=-1,
window_right=0 if is_causal else -1,
)
if api_family == "splitkv":
run_kwargs["num_splits"] = ns
bench_jobs.append(
(config, setup, run_kwargs, ns, api_family, is_causal)
)
if args.isolate and len(gpu_ids) > 1:
# ---- Multi-GPU parallel isolated execution ----
import tempfile
# Save input data once, shared by all subprocesses
shared_data_dir = tempfile.mkdtemp(prefix="fmha_shared_")
np.save(os.path.join(shared_data_dir, "Q.npy"), Q)
np.save(os.path.join(shared_data_dir, "K.npy"), K)
np.save(os.path.join(shared_data_dir, "V.npy"), V)
def _run_one(job, gpu_id):
config, setup, run_kwargs, ns, api_family, is_causal = job
# Per-job output dir (unique per subprocess)
job_dir = tempfile.mkdtemp(prefix=f"fmha_gpu{gpu_id}_")
# Symlink shared inputs instead of copying
for fname in ("Q.npy", "K.npy", "V.npy"):
os.symlink(
os.path.join(shared_data_dir, fname),
os.path.join(job_dir, fname),
)
time_ms, output, err = _run_kernel_isolated(
setup.library_path, args.arch, prob, run_kwargs, job_dir, gpu_id
)
shutil.rmtree(job_dir, ignore_errors=True)
return (config, time_ms, output, err, ns, api_family, is_causal, gpu_id)
print(f" Running {len(bench_jobs)} kernels across {len(gpu_ids)} GPUs ...")
for _, result in run_parallel_on_gpus(bench_jobs, gpu_ids, _run_one):
config, time_ms, output, err, ns, api_family, is_causal, gpu_id = result
if err:
splits_tag = f" [ns={ns}]" if api_family == "splitkv" else ""
print(
f" {config.name}{splits_tag:<105} {'---':>10} {'---':>10} {'---':>10} GPU{gpu_id} {err[:15]}"
)
continue
r, line = _compute_result(
config,
prob,
time_ms,
output,
ref,
is_causal,
ns,
api_family,
dtype_tol,
gpu_id,
)
print(line)
all_results.append(r)
shutil.rmtree(shared_data_dir, ignore_errors=True)
else:
# ---- Sequential execution (in-process or single-GPU isolated) ----
for config, setup, run_kwargs, ns, api_family, is_causal in bench_jobs:
time_ms = None
output = None
if args.isolate:
import tempfile
data_dir = tempfile.mkdtemp(prefix="fmha_run_")
np.save(os.path.join(data_dir, "Q.npy"), Q)
np.save(os.path.join(data_dir, "K.npy"), K)
np.save(os.path.join(data_dir, "V.npy"), V)
time_ms, output, err = _run_kernel_isolated(
setup.library_path,
args.arch,
prob,
run_kwargs,
data_dir,
gpu_ids[0],
)
shutil.rmtree(data_dir, ignore_errors=True)
if err:
print(
f" {config.name:<105} {'---':>10} {'---':>10} {'---':>10} {err[:20]:>6}"
)
continue
else:
result = setup.runner.run(Q, K, V, prob, **run_kwargs)
if not result.success:
continue
time_ms = result.time_ms
output = result.output
r, line = _compute_result(
config,
prob,
time_ms,
output,
ref,
is_causal,
ns,
api_family,
dtype_tol,
)
print(line)
all_results.append(r)
bench_time = time.perf_counter() - bench_t0
# Cleanup
for setup in setups:
if setup.success and setup.runner:
try:
setup.runner.cleanup()
except Exception:
pass
# Report
print(f"\n{'=' * 70}")
print(f" JIT: {jit_time:.0f}s ({built} kernels)")
print(f" Benchmark: {bench_time:.1f}s")
print(f" Results: {len(all_results)} measurements")
if all_results:
from collections import defaultdict
by_problem = defaultdict(list)
for r in all_results:
key = json.dumps(r["problem"], sort_keys=True)
by_problem[key].append(r)
print("\n Best kernel per problem:")
for key, results in by_problem.items():
best = max(results, key=lambda x: x["tflops"])
prob = json.loads(key)
ns_tag = f" [ns={best['num_splits']}]" if best.get("num_splits") else ""
h_str = (
f"H={prob['nhead_q']}"
if prob["nhead_q"] == prob["nhead_k"]
else f"Hq={prob['nhead_q']} Hk={prob['nhead_k']}"
)
s_str = (
f"S={prob['seqlen_q']}"
if prob["seqlen_q"] == prob["seqlen_k"]
else f"Sq={prob['seqlen_q']} Sk={prob['seqlen_k']}"
)
print(
f" B={prob['batch']} {h_str}"
f" {s_str} D={prob['hdim_q']}"
f" -> {best['kernel']}{ns_tag}"
f" ({best['tflops']:.2f} TFLOPS, {best['latency_ms']:.3f} ms)"
)
# CSV output: default to <build-dir>/results.csv; merge with existing file
# keeping the faster result (higher tflops) for duplicate kernel+problem keys.
_CSV_FIELDS = [
"kernel",
"dtype",
"pipeline",
"mode",
"mask",
"bias",
"hdim_q",
"hdim_v",
"tile_m0",
"tile_n0",
"tile_k0",
"tile_n1",
"tile_k1",
"tile_k0max",
"warp_m0",
"warp_n0",
"warp_k0",
"block_per_cu",
"num_splits",
"batch",
"nhead_q",
"nhead_k",
"seqlen_q",
"seqlen_k",
"latency_ms",
"tflops",
"max_err",
"status",
]
csv_path = args.csv if args.csv else str(build_dir / "results.csv")
if not args.no_csv and all_results:
# Build map of new results keyed by (kernel, problem-tuple)
def _csv_key(row):
p = row["problem"] if "problem" in row else row
return (
row["kernel"],
row.get("num_splits", 0),
p.get("batch"),
p.get("nhead_q"),
p.get("nhead_k"),
p.get("seqlen_q"),
p.get("seqlen_k"),
p.get("hdim_q"),
p.get("hdim_v"),
)
# Load existing CSV if present
existing = {}
if os.path.isfile(csv_path):
with open(csv_path, "r", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
# Convert numeric fields back from strings
for k in row:
if k in ("latency_ms", "tflops", "max_err"):
try:
row[k] = float(row[k])
except (ValueError, TypeError):
pass
elif k in (
"hdim_q",
"hdim_v",
"tile_m0",
"tile_n0",
"tile_k0",
"tile_n1",
"tile_k1",
"tile_k0max",
"warp_m0",
"warp_n0",
"warp_k0",
"block_per_cu",
"num_splits",
"batch",
"nhead_q",
"nhead_k",
"seqlen_q",
"seqlen_k",
):
try:
row[k] = int(row[k])
except (ValueError, TypeError):
pass
key = _csv_key(row)
existing[key] = row
# Merge new results — keep whichever is faster
for r in all_results:
row = {**r, **r["problem"]}
del row["problem"]
key = _csv_key(r)
prev = existing.get(key)
if prev is None or float(row.get("tflops", 0)) > float(
prev.get("tflops", 0)
):
existing[key] = row
# Write merged + sorted CSV
merged = sorted(
existing.values(), key=lambda x: float(x.get("tflops", 0)), reverse=True
)
with open(csv_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS, extrasaction="ignore")
writer.writeheader()
for row in merged:
writer.writerow(row)
print(f"\n CSV: {csv_path} ({len(merged)} rows, sorted by tflops)")
if args.json:
report = {
"metadata": {
"arch": args.arch,
"jit_time_s": jit_time,
"bench_time_s": bench_time,
"num_kernels": len(all_configs),
"num_built": built,
"num_problems": len(problems),
},
"results": all_results,
}
with open(args.json, "w") as f:
json.dump(report, f, indent=2)
print(f" JSON: {args.json}")
if args.log:
from datetime import datetime
with open(args.log, "w") as lf:
lf.write(f"FMHA Benchmark Log - {datetime.now().isoformat()}\n")
lf.write(f"{'=' * 80}\n\n")
lf.write(f"Command: {' '.join(sys.argv)}\n")
lf.write(f"Arch: {args.arch}\n")
lf.write(f"Tiles mode: {args.tiles}\n")
lf.write(f"Workers: {args.workers}\n")
lf.write(f"Build dir: {build_dir}\n")
lf.write(f"Total configs: {len(all_configs)}\n")
lf.write(f"Built: {built}\n")
lf.write(f"Failed: {failed}\n")
lf.write(f"JIT time: {jit_time:.1f}s\n")
lf.write(f"Bench time: {bench_time:.1f}s\n")
lf.write(f"Problems: {[str(p) for p in problems]}\n\n")
# All configs attempted
lf.write(f"{'=' * 80}\n")
lf.write(f"ALL CONFIGS ({len(all_configs)})\n")
lf.write(f"{'=' * 80}\n\n")
for i, (cfg, setup) in enumerate(zip(all_configs, setups)):
status = "OK" if setup.success else "FAILED"
lf.write(f"[{i:4d}] {status:6s} {cfg.name}\n")
lf.write(
f" tile=({cfg.tile_m0},{cfg.tile_n0},{cfg.tile_k0},{cfg.tile_n1},{cfg.tile_k1},{cfg.tile_k0max})"
f" warp=({cfg.warp_m0},{cfg.warp_n0},{cfg.warp_k0})"
f" bpc={cfg.block_per_cu}\n"
)
if not setup.success and setup.error:
lf.write(f" error: {setup.error}\n")
lf.write("\n")
# Failed configs summary
lf.write(f"\n{'=' * 80}\n")
lf.write(f"FAILED CONFIGS ({failed})\n")
lf.write(f"{'=' * 80}\n\n")
for cfg, setup in zip(all_configs, setups):
if not setup.success:
lf.write(f" {cfg.name}\n")
if setup.error:
lf.write(f" {setup.error}\n")
# Benchmark results
if all_results:
lf.write(f"\n{'=' * 80}\n")
lf.write(f"BENCHMARK RESULTS ({len(all_results)} measurements)\n")
lf.write(f"{'=' * 80}\n\n")
sorted_results = sorted(all_results, key=lambda x: -x["tflops"])
for r in sorted_results:
p = r["problem"]
lf.write(
f" {r['tflops']:8.2f} TFLOPS {r['latency_ms']:8.3f} ms"
f" B={p['batch']} H={p['nhead_q']} S={p['seqlen_q']} D={p['hdim_q']}"
f" {r['kernel']}\n"
)
print(f" Log: {args.log}")
print(f"{'=' * 70}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,689 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Full FMHA benchmark sweep.
JIT-compiles FMHA kernels, then for EACH test shape finds all matching
kernels and benchmarks them, streaming results incrementally to CSV/JSON.
Results are printed live per-shape with the best kernel highlighted.
TFLOPS and latency come directly from CK's HIP event timing.
Usage:
# Full sweep
python fmha_full_benchmark.py --workers 256
# Quick end-to-end test
python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 10 --workers 4
# Filter to h128 fp16
python fmha_full_benchmark.py --filter "c.hdim_q == 128 and c.data_type == 'fp16'"
"""
import argparse
import csv
import itertools
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import yaml
import numpy as np
_THIS_DIR = Path(__file__).resolve().parent
_DISPATCHER_ROOT = _THIS_DIR.parents[2] / "dispatcher"
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen"))
sys.path.insert(0, str(_THIS_DIR))
from fmha_utils import ( # noqa: E402
detect_gpu_arch,
setup_multiple_fmha_dispatchers,
)
from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402
YAML_PATH = _THIS_DIR / "ck_fmha_testing_matrix.yaml"
VARIANT_CONFIGS = {
"fwd": "configs/receipt0_fwd.json",
"splitkv": "configs/splitkv.json",
"pagedkv": "configs/pagedkv.json",
"appendkv": "configs/appendkv.json",
"batch_prefill": "configs/batch_prefill.json",
"bwd": "configs/bwd.json",
}
# Variant -> YAML section mapping. KV-cache variants use forward_tests shapes.
VARIANT_YAML_SECTIONS = {
"fwd": ["forward_tests"],
"splitkv": ["forward_tests"],
"pagedkv": ["forward_tests"],
"appendkv": ["forward_tests"],
"batch_prefill": ["forward_tests"],
"bwd": ["backward_tests"],
}
DTYPE_CK = {"fp16": "fp16", "bf16": "bf16", "fp8bf16": "fp8bf16", "fp8fp32": "fp8fp32"}
DTYPE_NP = {
"fp16": np.float16,
"bf16": np.float16,
"fp32": np.float32,
"fp8bf16": np.float16,
"fp8fp32": np.float16,
}
ELEM_BYTES = {"fp16": 2, "bf16": 2, "fp32": 4, "fp8bf16": 1, "fp8fp32": 1}
MASK_INT = {"no": 0, "top_left": 1, "generic": 3}
BIAS_INT = {"no": 0, "bias": 1, "alibi": 2}
KV_LAYOUT_INT = {"vectorized": 0, "linear": 1}
KV_LOOKUP_INT = {"vllm": 0, "sglang": 1}
@dataclass
class TestShape:
name: str
category: str
variant: str
batch: int
seqlen_q: int
seqlen_k: int
nhead_q: int
nhead_k: int
hdim_q: int
hdim_v: int
dtype: str
mask: str = "no_mask"
bias: str = "none"
dropout: float = 0.0
lse: bool = False
def parse_yaml(
yaml_path: Path, category: str = "smoke", sections: Optional[List[str]] = None
) -> List[TestShape]:
with open(yaml_path) as f:
data = yaml.safe_load(f)
shapes = []
cats = ["smoke"]
if category in ("full", "nightly"):
cats.append("full")
if category == "nightly":
cats.append("nightly")
section_variant_map = [("forward_tests", "fwd"), ("backward_tests", "bwd")]
if sections:
section_variant_map = [(s, v) for s, v in section_variant_map if s in sections]
for section, variant in section_variant_map:
if section not in data:
continue
for cat in cats:
for test in data[section].get(cat, []):
for combo in itertools.product(
test.get("batch", [1]),
test.get("seqlen_q", [1024]),
test.get("seqlen_k", [1024]),
test.get("nhead_q", [16]),
test.get("nhead_k", [16]),
test.get("hdim_q", [128]),
test.get("hdim_v", [128]),
test.get("dtype", ["fp16"]),
test.get("mask", ["no_mask"]),
test.get("bias", ["none"]),
test.get("dropout", [0.0]),
test.get("lse", [False]),
):
b, sq, sk, hq, hk, dq, dv, dt, m, bi, dr, ls = combo
shapes.append(
TestShape(
test["name"],
cat,
variant,
b,
sq,
sk,
hq,
hk,
dq,
dv,
dt,
mask=m,
bias=bi,
dropout=dr,
lse=ls,
)
)
return shapes
def bandwidth_gb_s(shape: TestShape, latency_ms: float) -> float:
if latency_ms <= 0:
return 0.0
eb = ELEM_BYTES.get(shape.dtype, 2)
total = (
shape.batch
* (
shape.nhead_q * shape.seqlen_q * shape.hdim_q
+ shape.nhead_k * shape.seqlen_k * shape.hdim_q
+ shape.nhead_k * shape.seqlen_k * shape.hdim_v
+ shape.nhead_q * shape.seqlen_q * shape.hdim_v
)
* eb
)
return total / (latency_ms * 1e6)
FAMILY_TO_API = {
"fwd": "fwd",
"fwd_splitkv": "splitkv",
"fwd_splitkv_combine": "splitkv",
"fwd_pagedkv": "pagedkv",
"fwd_appendkv": "appendkv",
"batch_prefill": "batch_prefill",
"bwd_dot_do_o": "bwd",
"bwd_dq_dk_dv": "bwd",
"bwd_convert_dq": "bwd",
}
def _config_to_serializable(config, so_path: str) -> dict:
"""Convert FmhaKernelConfig + so_path to a picklable dict for subprocess."""
return {
"so_path": so_path,
"api_family": FAMILY_TO_API.get(config.family, "fwd"),
"data_type": config.data_type,
"kernel": config.name,
"family": config.family,
"mode": config.mode,
"pipeline": config.pipeline,
"tile_m0": config.tile_m0,
"tile_n0": config.tile_n0,
"tile_k0": config.tile_k0,
"tile_n1": config.tile_n1,
"tile_k1": config.tile_k1,
"tile_k0max": config.tile_k0max,
"pad_s": config.pad_s,
"pad_sk": config.pad_sk,
"pad_d": config.pad_d,
"pad_dv": config.pad_dv,
"mask": config.mask,
"bias": config.bias,
"lse": config.lse,
"dropout": config.dropout,
"logits": config.logits,
"sink": config.sink,
"skip": config.skip_min_seqlen_q,
"qscale": config.qscale,
"paged_kv": config.paged_kv,
"rope": config.rope,
"deterministic": config.deterministic,
"dbias": config.dbias,
"mask_int": MASK_INT.get(config.mask, 0),
"bias_int": BIAS_INT.get(config.bias, 0),
"has_lse": int(config.lse),
"has_dropout": int(config.dropout not in (False, 0, "no", "False")),
"has_logits": int(config.logits),
"has_sink": int(config.sink),
"has_skip": int(config.skip_min_seqlen_q),
"has_dbias": int(getattr(config, "dbias", False)),
"is_store_randval": int(getattr(config, "store_randval", False)),
"page_size": getattr(config, "page_size", 16),
"kv_layout": KV_LAYOUT_INT.get(
getattr(config, "kv_memory_layout", "vectorized"), 0
),
"kv_lookup": KV_LOOKUP_INT.get(getattr(config, "kv_lookup_table", "sglang"), 1),
}
def _shape_to_dict(shape: TestShape) -> dict:
return {
"name": shape.name,
"category": shape.category,
"variant": shape.variant,
"batch": shape.batch,
"seqlen_q": shape.seqlen_q,
"seqlen_k": shape.seqlen_k,
"nhead_q": shape.nhead_q,
"nhead_k": shape.nhead_k,
"hdim_q": shape.hdim_q,
"hdim_v": shape.hdim_v,
"dtype": shape.dtype,
"mask": shape.mask,
"bias": shape.bias,
"dropout": shape.dropout,
"lse": shape.lse,
}
def main():
p = argparse.ArgumentParser(description="Full FMHA Benchmark Sweep")
p.add_argument("--arch", default=detect_gpu_arch())
p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"])
p.add_argument("--variant", default="all")
p.add_argument("--workers", type=int, default=8)
p.add_argument("--build-dir", default="/tmp/fmha_full_bench")
p.add_argument("--filter", dest="filter_expr", default="")
p.add_argument("--filter-file", default="")
p.add_argument("--csv", default="fmha_sweep_results.csv")
p.add_argument("--json", default="fmha_sweep_results.json")
p.add_argument("--compile-only", action="store_true")
p.add_argument("--max-kernels", type=int, default=0)
p.add_argument(
"--shape-timeout",
type=int,
default=600,
help="Per-shape timeout in seconds (0=none)",
)
args = p.parse_args()
build_dir = Path(args.build_dir)
build_dir.mkdir(parents=True, exist_ok=True)
variants = list(VARIANT_CONFIGS.keys()) if args.variant == "all" else [args.variant]
# ---- Phase 1: Parse shapes ----
print(f"\n{'=' * 80}")
print("Phase 1: Parse test shapes")
print(f"{'=' * 80}")
all_shapes: List[TestShape] = []
for variant in variants:
sections = VARIANT_YAML_SECTIONS.get(variant, ["forward_tests"])
vshapes = parse_yaml(YAML_PATH, args.category, sections=sections)
for s in vshapes:
s.variant = variant
all_shapes.extend(vshapes)
print(f" Category: {args.category}")
print(f" Variants: {variants}")
print(f" Total shapes: {len(all_shapes)}")
# ---- Phase 2: Compile ----
print(f"\n{'=' * 80}")
print("Phase 2: Compile kernels")
print(f"{'=' * 80}")
# kernel_index: (hdim_q, hdim_v, dtype, variant) -> list of (so_path, cfg_dict)
kernel_index: Dict[tuple, List[tuple]] = {}
from concurrent.futures import ProcessPoolExecutor as _PPE
_compile_pool = _PPE(max_workers=args.workers)
BATCH_SIZE = 200
for variant in variants:
cfg_path = str(_THIS_DIR / VARIANT_CONFIGS[variant])
if not Path(cfg_path).exists():
continue
configs = expand_sweep(cfg_path, args.arch)
if args.filter_expr or args.filter_file:
configs = apply_filter(configs, args.filter_expr, args.filter_file)
if args.max_kernels > 0:
configs = configs[: args.max_kernels]
if not configs:
continue
n_batches = (len(configs) + BATCH_SIZE - 1) // BATCH_SIZE
print(
f"\n {variant}: {len(configs)} configs, {args.workers} workers, {n_batches} batches..."
)
t0 = time.perf_counter()
setups = []
total_ok = 0
for bi in range(n_batches):
batch_cfgs = configs[bi * BATCH_SIZE : (bi + 1) * BATCH_SIZE]
batch_setups = setup_multiple_fmha_dispatchers(
batch_cfgs,
output_dir=build_dir,
max_workers=args.workers,
executor=_compile_pool,
)
batch_ok = sum(1 for s in batch_setups if s.success)
batch_n = len(batch_cfgs)
total_ok += batch_ok
setups.extend(zip(batch_cfgs, batch_setups))
del batch_setups, batch_cfgs
print(
f" Batch {bi + 1}/{n_batches}: {batch_ok}/{batch_n} "
f"(total {total_ok}, {time.perf_counter() - t0:.0f}s)",
flush=True,
)
ok = total_ok
print(f" Built {ok}/{len(configs)} in {time.perf_counter() - t0:.0f}s")
for config, setup in setups:
if not setup.success:
continue
so_path = getattr(setup, "library_path", "") or ""
if not so_path:
candidate = build_dir / f"libdispatcher_fmha_{config.name}.so"
if candidate.exists():
so_path = str(candidate)
if not so_path:
continue
cfg_dict = _config_to_serializable(config, so_path)
key = (config.hdim_q, config.hdim_v, config.data_type, variant, config.mode)
kernel_index.setdefault(key, []).append((so_path, cfg_dict))
_compile_pool.shutdown(wait=True)
del _compile_pool
total_built = sum(len(v) for v in kernel_index.values())
print(f"\n Total compiled: {total_built}")
print(f" Unique (hdim,dtype,variant) groups: {len(kernel_index)}")
if args.compile_only:
print(f"\n Compile-only. {total_built} kernels ready.")
return
# ---- Phase 3: Benchmark (serial, one subprocess per kernel) ----
print(f"\n{'=' * 80}")
print("Phase 3: Benchmark (one subprocess per kernel, serial GPU)")
print(f"{'=' * 80}")
csv_path = Path(args.csv) if os.path.isabs(args.csv) else _THIS_DIR / args.csv
csv_fields = [
"problem_name",
"batch",
"seqlen_q",
"seqlen_k",
"nhead_q",
"nhead_k",
"hdim_q",
"hdim_v",
"dtype",
"kernel",
"family",
"mode",
"pipeline",
"tile_m0",
"tile_n0",
"tile_k0",
"tile_n1",
"tile_k1",
"tile_k0max",
"pad_s",
"pad_sk",
"pad_d",
"pad_dv",
"mask",
"bias",
"lse",
"dropout",
"logits",
"sink",
"skip",
"qscale",
"paged_kv",
"rope",
"deterministic",
"dbias",
"latency_ms",
"tflops",
"bandwidth_gb_s",
]
# Resume: load already-completed measurements
completed: set = set()
if csv_path.exists() and csv_path.stat().st_size > 0:
with open(csv_path, newline="") as f:
for row in csv.DictReader(f):
completed.add(
(
row.get("kernel", ""),
row.get("problem_name", ""),
str(row.get("batch", "")),
str(row.get("seqlen_q", "")),
row.get("dtype", ""),
)
)
csv_file = open(csv_path, "a", newline="")
writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
print(f" Resuming: {len(completed)} measurements already in CSV")
else:
csv_file = open(csv_path, "w", newline="")
writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
writer.writeheader()
# Pre-filter: match shapes to kernels by (hdim, dtype, variant, mode).
# YAML shapes are batch-mode only. Group-mode kernels need seqstart arrays
# which batch shapes don't provide -- they all GPU fault.
runnable = []
for shape in all_shapes:
ck_dtype = DTYPE_CK.get(shape.dtype, shape.dtype)
key = (shape.hdim_q, shape.hdim_v, ck_dtype, shape.variant, "batch")
entries = kernel_index.get(key, [])
if entries:
runnable.append((shape, entries))
# Flatten to work items, skip already-completed
def _resume_key(cfg, shape):
return (
cfg.get("kernel", ""),
shape.name,
str(shape.batch),
str(shape.seqlen_q),
shape.dtype,
)
work_items = []
skipped = 0
for shape, kernel_entries in runnable:
for so_path, cfg in kernel_entries:
if _resume_key(cfg, shape) in completed:
skipped += 1
else:
work_items.append((shape, so_path, cfg))
total_work = len(work_items) + skipped
total_measurements = len(completed)
total_gpu_faults = 0
bench_t0 = time.perf_counter()
BENCH_BATCH = 50
worker_path = _THIS_DIR / "run_one_kernel.py"
worker_env = os.environ.copy()
worker_env["FMHA_PYPATH_1"] = str(_DISPATCHER_ROOT / "python")
worker_env["FMHA_PYPATH_2"] = str(_DISPATCHER_ROOT / "codegen")
CFG_KEYS = [
"kernel",
"family",
"mode",
"pipeline",
"tile_m0",
"tile_n0",
"tile_k0",
"tile_n1",
"tile_k1",
"tile_k0max",
"pad_s",
"pad_sk",
"pad_d",
"pad_dv",
"mask",
"bias",
"lse",
"dropout",
"logits",
"sink",
"skip",
"qscale",
"paged_kv",
"rope",
"deterministic",
"dbias",
]
print(f" Runnable shapes: {len(runnable)}")
print(f" Total kernel x shape pairs: {total_work}")
print(f" Already completed: {skipped}")
print(f" Pending: {len(work_items)}")
print(f" Batch size: {BENCH_BATCH} (retry individually on fault)")
print()
def _run_subprocess(payload_bytes, timeout=10):
proc = subprocess.Popen(
[sys.executable, str(worker_path)],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
env=worker_env,
)
timed_out = False
stdout_bytes = b""
try:
stdout_bytes, _ = proc.communicate(input=payload_bytes, timeout=timeout)
except subprocess.TimeoutExpired:
proc.kill()
proc.communicate()
timed_out = True
finally:
pid = proc.pid
if proc.poll() is None:
proc.kill()
proc.wait()
for pipe in [proc.stdin, proc.stdout, proc.stderr]:
if pipe and not pipe.closed:
pipe.close()
gpucore = _THIS_DIR / f"gpucore.{pid}"
if gpucore.exists():
gpucore.unlink(missing_ok=True)
rc = -1 if timed_out else proc.returncode
return stdout_bytes, rc
def _record_result(r, shape, cfg, shape_dict):
nonlocal total_measurements
lat_ms, tflops = r["ms"], r["tflops"]
bw = bandwidth_gb_s(shape, lat_ms)
row = {
"problem_name": shape.name,
"batch": shape.batch,
"seqlen_q": shape.seqlen_q,
"seqlen_k": shape.seqlen_k,
"nhead_q": shape.nhead_q,
"nhead_k": shape.nhead_k,
"hdim_q": shape.hdim_q,
"hdim_v": shape.hdim_v,
"dtype": shape.dtype,
}
for k in CFG_KEYS:
row[k] = cfg.get(k, "")
row["latency_ms"] = round(lat_ms, 4)
row["tflops"] = round(tflops, 2)
row["bandwidth_gb_s"] = round(bw, 2)
writer.writerow(row)
csv_file.flush()
total_measurements += 1
return tflops, lat_ms
# Process in batches
n_batches = (len(work_items) + BENCH_BATCH - 1) // BENCH_BATCH
processed = 0
for bi in range(n_batches):
batch = work_items[bi * BENCH_BATCH : (bi + 1) * BENCH_BATCH]
items = []
for shape, so_path, cfg in batch:
cfg["so_path"] = so_path
items.append(
{"so_path": so_path, "shape": _shape_to_dict(shape), "cfg": cfg}
)
batch_timeout = len(batch) * 2 + 5
payload = json.dumps({"items": items}).encode()
stdout_bytes, rc = _run_subprocess(payload, timeout=batch_timeout)
if rc == 0:
batch_ok = 0
for line in stdout_bytes.decode().strip().split("\n"):
if not line:
continue
try:
r = json.loads(line)
except (json.JSONDecodeError, ValueError):
continue
idx = r.get("idx", -1)
if not r.get("ok") or idx < 0 or idx >= len(batch):
continue
shape, so_path, cfg = batch[idx]
_record_result(r, shape, cfg, items[idx]["shape"])
batch_ok += 1
processed += len(batch)
print(
f" [batch {bi + 1}/{n_batches}] {batch_ok}/{len(batch)} ok "
f"({processed}/{len(work_items)} done, {total_measurements} total)",
flush=True,
)
else:
# Collect partial results flushed before the fault
partial_done = set()
for line in stdout_bytes.decode().strip().split("\n"):
if not line:
continue
try:
r = json.loads(line)
except (json.JSONDecodeError, ValueError):
continue
idx = r.get("idx", -1)
if r.get("ok") and 0 <= idx < len(batch):
shape, so_path, cfg = batch[idx]
_record_result(r, shape, cfg, items[idx]["shape"])
partial_done.add(idx)
# Retry the rest one by one
retry = [(i, b) for i, b in enumerate(batch) if i not in partial_done]
print(
f" [batch {bi + 1}/{n_batches}] FAULT after {len(partial_done)}/{len(batch)} ok, "
f"retrying {len(retry)} individually...",
flush=True,
)
for idx, (shape, so_path, cfg) in retry:
cfg["so_path"] = so_path
p = json.dumps(
{"so_path": so_path, "shape": items[idx]["shape"], "cfg": cfg}
).encode()
out, single_rc = _run_subprocess(p, timeout=10)
if single_rc != 0:
total_gpu_faults += 1
continue
try:
r = json.loads(out.decode().strip().split("\n")[0])
except (json.JSONDecodeError, ValueError):
continue
if r.get("ok"):
tflops, lat_ms = _record_result(r, shape, cfg, items[idx]["shape"])
print(
f" {tflops:>7.1f} TFLOPS {lat_ms:.4f}ms "
f"{cfg.get('kernel', '?')[:45]} | {shape.name}",
flush=True,
)
processed += len(batch)
print(f" ({processed}/{len(work_items)} done)", flush=True)
csv_file.close()
bench_time = time.perf_counter() - bench_t0
# ---- Phase 4: Summary ----
print(f"\n{'=' * 80}")
print("Results")
print(f"{'=' * 80}")
print(f" Total work items: {total_work}")
print(f" Skipped (resumed): {skipped}")
print(f" Measurements: {total_measurements}")
print(f" GPU faults: {total_gpu_faults}")
print(f" Benchmark time: {bench_time:.1f}s")
print(f" CSV: {csv_path}")
print(f"{'=' * 80}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Full FMHA benchmark sweep, organized by variant and dtype.
Compiles all kernels per variant (shared build dir for caching),
benchmarks against all smoke shapes, then splits results into:
<output_dir>/
fwd/fp16/results.csv
fwd/bf16/results.csv
splitkv/fp16/results.csv
...
bwd_dot_do_o/fp16/results.csv
bwd_dq_dk_dv/fp16/results.csv
bwd_convert_dq/fp16/results.csv
Usage:
python run_full_sweep.py --workers 256
python run_full_sweep.py --workers 256 --category full --output /tmp/fmha_sweep
"""
import argparse
import csv
import os
import subprocess
import sys
import time
from collections import defaultdict
from pathlib import Path
_THIS_DIR = Path(__file__).resolve().parent
VARIANTS = ["fwd", "splitkv", "pagedkv", "appendkv", "batch_prefill", "bwd"]
BWD_FAMILIES = ["bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"]
def run_variant(variant, category, workers, build_dir, raw_csv, shape_timeout=600):
"""Run fmha_full_benchmark.py for one variant, return path to raw CSV."""
cmd = [
sys.executable,
str(_THIS_DIR / "fmha_full_benchmark.py"),
"--category",
category,
"--variant",
variant,
"--workers",
str(workers),
"--build-dir",
str(build_dir),
"--csv",
str(raw_csv),
"--json",
str(raw_csv.with_suffix(".json")),
"--shape-timeout",
str(shape_timeout),
]
print(f"\n{'=' * 80}")
print(f" Variant: {variant}")
print(f" Command: {' '.join(cmd)}")
print(f"{'=' * 80}", flush=True)
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
proc = subprocess.run(cmd, env=env)
return proc.returncode
def split_csv(raw_csv, output_dir):
"""Split a raw CSV into per-family per-dtype subdirectories."""
if not raw_csv.exists():
return {}
counts = defaultdict(int)
writers = {}
files = {}
with open(raw_csv, newline="") as f:
reader = csv.DictReader(f)
fieldnames = reader.fieldnames
for row in reader:
family = row.get("family", "unknown")
dtype = row.get("dtype", "unknown")
key = (family, dtype)
if key not in writers:
d = output_dir / family / dtype
d.mkdir(parents=True, exist_ok=True)
fh = open(d / "results.csv", "w", newline="")
w = csv.DictWriter(fh, fieldnames=fieldnames)
w.writeheader()
writers[key] = w
files[key] = fh
writers[key].writerow(row)
counts[key] += 1
for fh in files.values():
fh.close()
return counts
def main():
p = argparse.ArgumentParser(
description="Full FMHA Sweep (organized by variant/dtype)"
)
p.add_argument("--workers", type=int, default=256)
p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"])
p.add_argument("--output", default="/tmp/fmha_sweep")
p.add_argument("--build-dir", default="/tmp/fmha_sweep_build")
p.add_argument(
"--variants",
nargs="+",
default=VARIANTS,
choices=VARIANTS,
help="Which variants to run",
)
p.add_argument(
"--shape-timeout", type=int, default=600, help="Per-shape timeout in seconds"
)
args = p.parse_args()
output_dir = Path(args.output)
build_dir = Path(args.build_dir)
output_dir.mkdir(parents=True, exist_ok=True)
build_dir.mkdir(parents=True, exist_ok=True)
t0 = time.perf_counter()
grand_total = defaultdict(int)
for variant in args.variants:
raw_csv = output_dir / f"_raw_{variant}.csv"
rc = run_variant(
variant, args.category, args.workers, build_dir, raw_csv, args.shape_timeout
)
if rc != 0:
print(f"\n WARNING: {variant} exited with code {rc}", flush=True)
counts = split_csv(raw_csv, output_dir)
for key, n in counts.items():
grand_total[key] += n
family, dtype = key
print(f" {family}/{dtype}: {n} measurements")
elapsed = time.perf_counter() - t0
print(f"\n{'=' * 80}")
print("SWEEP COMPLETE")
print(f"{'=' * 80}")
print(f" Total time: {elapsed / 60:.1f} min")
print(f" Output dir: {output_dir}")
print()
print(f" {'Family':<25} {'Dtype':<10} {'Measurements':>12}")
print(f" {'-' * 25} {'-' * 10} {'-' * 12}")
total = 0
for (family, dtype), n in sorted(grand_total.items()):
print(f" {family:<25} {dtype:<10} {n:>12,}")
total += n
print(f" {'-' * 25} {'-' * 10} {'-' * 12}")
print(f" {'TOTAL':<25} {'':<10} {total:>12,}")
print("\n Directory structure:")
for d in sorted(output_dir.rglob("results.csv")):
rel = d.relative_to(output_dir)
print(f" {rel}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Run FMHA kernel(s) on GPU and report timing.
Single mode: stdin = {"so_path": ..., "shape": {...}, "cfg": {...}}
Batch mode: stdin = {"items": [{"so_path": ..., "shape": {...}, "cfg": {...}}, ...]}
Each result prints one JSON line to stdout (flushed immediately):
{"idx": 0, "ok": true, "ms": 0.123, "tflops": 456.7}
{"idx": 1, "ok": false}
Flushing per-line lets the parent recover partial results if a later
kernel causes a GPU fault that kills this process.
"""
import json
import os
import sys
import numpy as np
for p in [os.environ.get("FMHA_PYPATH_1", ""), os.environ.get("FMHA_PYPATH_2", "")]:
if p and p not in sys.path:
sys.path.insert(0, p)
from fmha_utils import FmhaProblem, FmhaRunner # noqa: E402
DTYPE_NP = {
"fp16": np.float16,
"bf16": np.float16,
"fp32": np.float32,
"fp8bf16": np.float16,
"fp8fp32": np.float16,
}
def _run_one(idx, so_path, s, cfg):
prob = FmhaProblem(
batch=s["batch"],
nhead_q=s["nhead_q"],
nhead_k=s["nhead_k"],
seqlen_q=s["seqlen_q"],
seqlen_k=s["seqlen_k"],
hdim_q=s["hdim_q"],
hdim_v=s["hdim_v"],
)
dt = DTYPE_NP.get(s.get("dtype", "fp16"), np.float16)
np.random.seed(42)
q = (np.random.randn(*prob.q_shape()) * 0.1).astype(dt)
k = (np.random.randn(*prob.k_shape()) * 0.1).astype(dt)
v = (np.random.randn(*prob.v_shape()) * 0.1).astype(dt)
runner = FmhaRunner.from_library(so_path)
api = cfg.get("api_family", "fwd")
if api == "bwd":
out_buf = (
np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"], s["hdim_v"]) * 0.1
).astype(dt)
lse = np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"]).astype(
np.float32
)
d_out = (np.random.randn(*out_buf.shape) * 0.1).astype(dt)
result = runner.run_bwd(
q,
k,
v,
out_buf,
lse,
d_out,
prob,
data_type=cfg.get("data_type", "fp16"),
mask_type=cfg.get("mask_int", 0),
bias_type=cfg.get("bias_int", 0),
has_dropout=cfg.get("has_dropout", 0),
has_dbias=cfg.get("has_dbias", 0),
is_deterministic=cfg.get("deterministic", 0),
is_group_mode=cfg.get("mode", "batch") == "group",
is_store_randval=cfg.get("is_store_randval", 0),
tile_n0=cfg.get("tile_n0", 128),
)
else:
result = runner.run(
q,
k,
v,
prob,
mask_type=cfg.get("mask_int", 0),
bias_type=cfg.get("bias_int", 0),
has_lse=cfg.get("has_lse", 0),
has_dropout=cfg.get("has_dropout", 0),
has_logits=cfg.get("has_logits", 0),
has_sink=cfg.get("has_sink", 0),
has_skip=cfg.get("has_skip", 0),
api_family=api,
data_type=cfg.get("data_type", "fp16"),
page_size=cfg.get("page_size", 16),
kv_layout=cfg.get("kv_layout", 0),
kv_lookup=cfg.get("kv_lookup", 1),
is_group_mode=cfg.get("mode", "batch") == "group",
)
if result.success:
print(
json.dumps(
{"idx": idx, "ok": True, "ms": result.time_ms, "tflops": result.tflops}
),
flush=True,
)
else:
print(json.dumps({"idx": idx, "ok": False}), flush=True)
def main():
d = json.loads(sys.stdin.buffer.read())
if "items" in d:
for i, item in enumerate(d["items"]):
_run_one(i, item["so_path"], item["shape"], item["cfg"])
else:
_run_one(0, d["cfg"]["so_path"], d["shape"], d["cfg"])
if __name__ == "__main__":
main()