mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher (#5260) ## Motivation The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow. ## Technical Details This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture. Seven new C++ runtime headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage). The codegen pipeline is driven by a fmha_arch_specs.json capturing per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8). Python integration (fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support. Twenty-seven C++ and thirty-two Python examples cover the full feature surface — forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps — all JIT-compiled on the fly. ## Test Plan Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match. ## Test Result The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching up to 375 TFLOPS at seqlen 2048. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com> Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com> Co-authored-by: Maksim (Max) Podkorytov <Maksim.Podkorytov@amd.com> Co-authored-by: yashagar <yashagar@amd.com>
This commit is contained in:
committed by
GitHub
parent
cc5c79a1e7
commit
b20458e19e
@@ -6,6 +6,7 @@ include_directories(BEFORE
|
||||
${CMAKE_CURRENT_LIST_DIR}/ops
|
||||
)
|
||||
|
||||
add_subdirectory(ops/fmha EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ops/gemm EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ops/gemm_streamk EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ops/pooling EXCLUDE_FROM_ALL)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
| GEMM | grouped_gemm_quant | | ❌ | | ❌ | | | | ❌ | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Reduce | multi_reduce2d [8]<br>engine: reduce/<br>example: 05_reduce/ | ✅ | | ❌ | | | | | | | | | ❌ | ✅ | ✅ | ❌ |
|
||||
| Reduce | reduce2d<br>example: 05_reduce/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Attention | fmha<br>example: 01_fmha/ | ❌ | ❌ | ❌ | ❌ | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Attention | fmha<br>engine: fmha/<br>example: 01_fmha/ | ✅ | ✅ | ✅ | ❌ | | | | | | | | ✅ | ✅ | ✅ | ❌ |
|
||||
| Attention | sparse_attn<br>example: 50_sparse_attn/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Activation | topk_softmax<br>example: 09_topk_softmax/ | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
63
tile_engine/ops/common/parallel_runner.py
Normal file
63
tile_engine/ops/common/parallel_runner.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Generic multi-GPU parallel job runner for tile engine benchmarks.
|
||||
|
||||
Op-agnostic: takes opaque jobs, distributes them across GPUs with one
|
||||
job per GPU at a time, and yields results in completion order. Used by
|
||||
fmha_benchmark.py and reusable for gemm/reduce/pooling benchmarks.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple
|
||||
|
||||
|
||||
def run_parallel_on_gpus(
|
||||
jobs: List[Any],
|
||||
gpu_ids: List[int],
|
||||
run_one: Callable[[Any, int], Any],
|
||||
max_workers: Optional[int] = None,
|
||||
) -> Iterator[Tuple[int, Any]]:
|
||||
"""Dispatch jobs across GPUs, one job per GPU at a time.
|
||||
|
||||
Args:
|
||||
jobs: Opaque job objects passed to run_one.
|
||||
gpu_ids: GPU IDs to use (e.g. [0,1,2,3]). At most one job per GPU runs concurrently.
|
||||
run_one: Callable run_one(job, gpu_id) -> result. Caller is responsible
|
||||
for any subprocess isolation, environment setup, etc.
|
||||
max_workers: Thread pool size. Defaults to len(gpu_ids).
|
||||
|
||||
Yields:
|
||||
(job_index, result) tuples in completion order. Caller can sort by
|
||||
job_index to restore submission order if needed.
|
||||
"""
|
||||
if not jobs:
|
||||
return
|
||||
if max_workers is None:
|
||||
max_workers = len(gpu_ids)
|
||||
|
||||
# One job per GPU at a time
|
||||
gpu_semas = {gid: threading.Semaphore(1) for gid in gpu_ids}
|
||||
cycle = [0]
|
||||
cycle_lock = threading.Lock()
|
||||
|
||||
def _pick_gpu() -> int:
|
||||
with cycle_lock:
|
||||
gid = gpu_ids[cycle[0] % len(gpu_ids)]
|
||||
cycle[0] += 1
|
||||
return gid
|
||||
|
||||
def _wrapper(job_idx: int, job: Any) -> Tuple[int, Any]:
|
||||
gid = _pick_gpu()
|
||||
gpu_semas[gid].acquire()
|
||||
try:
|
||||
return job_idx, run_one(job, gid)
|
||||
finally:
|
||||
gpu_semas[gid].release()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_wrapper, i, j) for i, j in enumerate(jobs)]
|
||||
for fut in as_completed(futures):
|
||||
yield fut.result()
|
||||
3
tile_engine/ops/fmha/.gitignore
vendored
Normal file
3
tile_engine/ops/fmha/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.log
|
||||
build/
|
||||
*_build*/
|
||||
94
tile_engine/ops/fmha/CMakeLists.txt
Normal file
94
tile_engine/ops/fmha/CMakeLists.txt
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# FMHA Tile Engine -- Pure Python benchmarking via the CK dispatcher.
|
||||
# No C++ per-kernel targets; all compilation is JIT via the dispatcher.
|
||||
|
||||
set(FMHA_TE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(FMHA_TE_CONFIGS ${FMHA_TE_DIR}/configs)
|
||||
|
||||
include(ProcessorCount)
|
||||
ProcessorCount(NPROC)
|
||||
if(NPROC EQUAL 0)
|
||||
set(NPROC 8)
|
||||
endif()
|
||||
|
||||
# Use first arch from SUPPORTED_GPU_TARGETS, or fallback to gfx950
|
||||
set(FMHA_BENCH_ARCH "gfx950")
|
||||
if(SUPPORTED_GPU_TARGETS)
|
||||
list(GET SUPPORTED_GPU_TARGETS 0 FMHA_BENCH_ARCH)
|
||||
endif()
|
||||
|
||||
# Main benchmark target (runs forward sweep by default)
|
||||
add_custom_target(benchmark_fmha
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
|
||||
${FMHA_TE_CONFIGS}/fwd.json
|
||||
--arch ${FMHA_BENCH_ARCH}
|
||||
--workers ${NPROC}
|
||||
--best
|
||||
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_fwd_results.json
|
||||
WORKING_DIRECTORY ${FMHA_TE_DIR}
|
||||
COMMENT "FMHA tile engine benchmark (forward)"
|
||||
)
|
||||
|
||||
if(TARGET ck_tile_dispatcher)
|
||||
add_dependencies(benchmark_fmha ck_tile_dispatcher)
|
||||
endif()
|
||||
|
||||
# Per-variant convenience targets
|
||||
foreach(variant fwd bwd splitkv appendkv pagedkv batch_prefill)
|
||||
if(EXISTS ${FMHA_TE_CONFIGS}/${variant}.json)
|
||||
add_custom_target(benchmark_fmha_${variant}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
|
||||
${FMHA_TE_CONFIGS}/${variant}.json
|
||||
--arch ${FMHA_BENCH_ARCH}
|
||||
--workers ${NPROC}
|
||||
--best
|
||||
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_${variant}_results.json
|
||||
WORKING_DIRECTORY ${FMHA_TE_DIR}
|
||||
COMMENT "FMHA tile engine benchmark (${variant})"
|
||||
)
|
||||
if(TARGET ck_tile_dispatcher)
|
||||
add_dependencies(benchmark_fmha_${variant} ck_tile_dispatcher)
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# CI target (minimal sweep for quick validation)
|
||||
if(EXISTS ${FMHA_TE_CONFIGS}/fwd_ci.json)
|
||||
add_custom_target(benchmark_fmha_ci
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
|
||||
${FMHA_TE_CONFIGS}/fwd_ci.json
|
||||
--arch ${FMHA_BENCH_ARCH}
|
||||
--workers 8
|
||||
--verify
|
||||
WORKING_DIRECTORY ${FMHA_TE_DIR}
|
||||
COMMENT "FMHA tile engine CI benchmark"
|
||||
)
|
||||
if(TARGET ck_tile_dispatcher)
|
||||
add_dependencies(benchmark_fmha_ci ck_tile_dispatcher)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# All-variants target
|
||||
set(FMHA_ALL_CONFIGS "")
|
||||
foreach(cfg fwd bwd splitkv appendkv pagedkv batch_prefill)
|
||||
if(EXISTS ${FMHA_TE_CONFIGS}/${cfg}.json)
|
||||
list(APPEND FMHA_ALL_CONFIGS ${FMHA_TE_CONFIGS}/${cfg}.json)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_custom_target(benchmark_fmha_all
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
|
||||
${FMHA_ALL_CONFIGS}
|
||||
--arch ${FMHA_BENCH_ARCH}
|
||||
--workers ${NPROC}
|
||||
--best
|
||||
--json ${CMAKE_CURRENT_BINARY_DIR}/fmha_all_results.json
|
||||
WORKING_DIRECTORY ${FMHA_TE_DIR}
|
||||
COMMENT "FMHA tile engine benchmark (all variants)"
|
||||
)
|
||||
|
||||
if(TARGET ck_tile_dispatcher)
|
||||
add_dependencies(benchmark_fmha_all ck_tile_dispatcher)
|
||||
endif()
|
||||
192
tile_engine/ops/fmha/README.md
Normal file
192
tile_engine/ops/fmha/README.md
Normal file
@@ -0,0 +1,192 @@
|
||||
# FMHA Tile Engine
|
||||
|
||||
Benchmarking and kernel enumeration for Fused Multi-Head Attention (FMHA) via the CK dispatcher's pipelined JIT compilation.
|
||||
|
||||
Covers all 9 FMHA kernel families: Forward, Split-KV (main + combine), Paged-KV, Append-KV, Batch Prefill, and Backward (dot\_do\_o, dq\_dk\_dv, convert\_dq) -- totaling 33,541 unique kernel specializations on gfx950.
|
||||
|
||||
## Directory Layout
|
||||
|
||||
```
|
||||
fmha/
|
||||
fmha_instance_builder.py Kernel enumeration from JSON config + pipeline rules
|
||||
fmha_benchmark.py Single-config JIT compile and GPU benchmark
|
||||
fmha_full_benchmark.py Full sweep: compile all kernels, benchmark across test shapes
|
||||
ck_fmha_testing_matrix.yaml Test shapes (smoke / full / nightly)
|
||||
CMakeLists.txt CMake targets
|
||||
README.md This file
|
||||
configs/ Sweep definitions (JSON)
|
||||
receipt0_fwd.json Full receipt-0 forward: ~12K kernels
|
||||
fwd.json Forward variants
|
||||
fwd_ci.json Minimal CI subset
|
||||
bwd.json Backward variants
|
||||
splitkv.json Split-KV
|
||||
appendkv.json Append-KV
|
||||
pagedkv.json Paged-KV
|
||||
batch_prefill.json Batch prefill
|
||||
filters/ Sample Python filter scripts
|
||||
h128_no_dropout.py Keep only h128 without dropout
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Count kernels without compiling
|
||||
python fmha_instance_builder.py configs/receipt0_fwd.json --count-only
|
||||
|
||||
# Minimal CI build + run (~16 kernels, <1 min)
|
||||
python fmha_benchmark.py configs/fwd_ci.json --workers 128 --verify
|
||||
|
||||
# Full forward receipt-0 compile-only (12K kernels, ~10 min with 256 workers)
|
||||
python fmha_benchmark.py configs/receipt0_fwd.json --workers 256 --compile-only
|
||||
|
||||
# Full sweep: compile every fwd kernel, benchmark against all smoke shapes
|
||||
python fmha_full_benchmark.py --category smoke --variant fwd --workers 256
|
||||
|
||||
# Quick end-to-end test (2 kernels, 1 shape)
|
||||
python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 2 --workers 4
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Kernel Enumeration
|
||||
|
||||
```
|
||||
JSON config (variant + trait_config allow-list)
|
||||
--> fmha_instance_builder.py
|
||||
--> fmha_pipeline_rules.py (self-contained CK parity logic)
|
||||
--> fmha_arch_specs.json (tile tables per arch / dtype / hdim)
|
||||
--> list of FmhaKernelConfig (33,541 total on gfx950)
|
||||
--> optional --filter / --filter-file
|
||||
```
|
||||
|
||||
The pipeline rules in `dispatcher/codegen/fmha_pipeline_rules.py` reproduce the exact kernel enumeration from CK Tile's `01_fmha/codegen/`, including per-arch tile constraints, pipeline selection, padding variants, and feature products. Parity is verified by `dispatcher/tests/validate_arch_specs_parity.py`.
|
||||
|
||||
### Benchmark Tools
|
||||
|
||||
**`fmha_benchmark.py`** -- single-config benchmark. Input: one JSON config (kernel definitions). JIT-compiles all matching kernels, runs each on a given problem size, reports per-kernel timing and optional CPU validation. Optionally writes `--csv` output.
|
||||
|
||||
**`fmha_full_benchmark.py`** -- full sweep benchmark. Input: `ck_fmha_testing_matrix.yaml` (test shapes) + JSON configs (kernel definitions). Compiles all kernel variants for selected families, then iterates over test shapes, matching each shape to compatible compiled kernels and benchmarking every match. Writes `--csv` and `--json` output.
|
||||
|
||||
### JIT Compilation Pipeline
|
||||
|
||||
Both tools use the dispatcher's `setup_multiple_fmha_dispatchers()` which implements a 3-stage pipelined build:
|
||||
|
||||
1. **Codegen** (parallel) -- generate C++ kernel specializations and ctypes wrappers
|
||||
2. **Compile** (parallel) -- `hipcc` compile each kernel and ctypes lib
|
||||
3. **Link + Load** (parallel) -- produce `.so` libraries, load via ctypes
|
||||
|
||||
With 256 workers, throughput is roughly 5-10 kernels/sec depending on kernel complexity.
|
||||
|
||||
## JSON Config Format
|
||||
|
||||
Each config specifies a `variant` and an optional `trait_config` that acts as an allow-list filter:
|
||||
|
||||
```json
|
||||
{
|
||||
"variant": "fwd",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16"]},
|
||||
"pipeline": {"values": ["qr_async"]},
|
||||
"mode": {"values": ["batch"]},
|
||||
"mask": {"values": ["no"]},
|
||||
"bias": {"values": ["no"]},
|
||||
"lse": {"values": [false]},
|
||||
"dropout": {"values": [false]},
|
||||
"logits": {"values": [false]},
|
||||
"sink": {"values": [false]}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If a trait key is absent, all values pass. The `receipt0_fwd.json` config only restricts `data_type` to exclude fp32, giving the full ~12K forward kernel set.
|
||||
|
||||
## Filtering
|
||||
|
||||
### CLI expression
|
||||
|
||||
```bash
|
||||
python fmha_benchmark.py configs/receipt0_fwd.json \
|
||||
--filter "c.hdim_q == 128 and c.pipeline == 'qr_async'"
|
||||
|
||||
python fmha_full_benchmark.py --variant fwd \
|
||||
--filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'"
|
||||
```
|
||||
|
||||
The expression accesses `c` (an `FmhaKernelConfig` dataclass) with fields: `data_type`, `mode`, `hdim_q`, `hdim_v`, `pipeline`, `tile_m0`, `tile_n0`, `tile_k0`, `pad_s`, `pad_sk`, `pad_d`, `pad_dv`, `mask`, `bias`, `lse`, `dropout`, `logits`, `sink`, `skip_min_seqlen_q`, `qscale`, `paged_kv`, `rope`, `deterministic`, `dbias`, `dropout_variant`.
|
||||
|
||||
### Python file filter
|
||||
|
||||
```bash
|
||||
python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py
|
||||
```
|
||||
|
||||
The file must define `filter_config(c) -> bool`. Both `--filter` and `--filter-file` combine with AND logic.
|
||||
|
||||
## Test Shape Matrix
|
||||
|
||||
`ck_fmha_testing_matrix.yaml` defines test problems in three tiers:
|
||||
|
||||
| Category | Purpose | Shapes |
|
||||
|----------|---------|--------|
|
||||
| `smoke` | Pre-submit sanity, <5 min | ~365 |
|
||||
| `full` | Post-submit validation | smoke + ~1,500 |
|
||||
| `nightly`| Exhaustive sweep | all |
|
||||
|
||||
Shapes cover representative configurations: GQA ratios, asymmetric head dims, non-power-of-2 sequences, FP8 variants, long sequences, and cross-attention patterns.
|
||||
|
||||
## Output Format
|
||||
|
||||
### CSV
|
||||
|
||||
```
|
||||
problem_name,batch,seqlen_q,seqlen_k,nhead_q,nhead_k,hdim_q,hdim_v,dtype,
|
||||
kernel,family,mode,pipeline,tile_m0,tile_n0,tile_k0,...,
|
||||
latency_ms,tflops,bandwidth_gb_s
|
||||
```
|
||||
|
||||
Every column needed to fully reconstruct the kernel identity is included. TFLOPS and latency come directly from CK's internal HIP event timing.
|
||||
|
||||
### JSON
|
||||
|
||||
```json
|
||||
{
|
||||
"metadata": {
|
||||
"arch": "gfx950",
|
||||
"category": "smoke",
|
||||
"total_kernels": 600,
|
||||
"shapes_benchmarked": 42,
|
||||
"total_measurements": 12600
|
||||
},
|
||||
"results": [...]
|
||||
}
|
||||
```
|
||||
|
||||
## CMake Targets
|
||||
|
||||
```bash
|
||||
make benchmark_fmha # Forward sweep
|
||||
make benchmark_fmha_ci # Quick CI validation
|
||||
make benchmark_fmha_bwd # Backward sweep
|
||||
make benchmark_fmha_all # All variants
|
||||
make benchmark_fmha_splitkv # Split-KV only
|
||||
```
|
||||
|
||||
## Parity Verification
|
||||
|
||||
```bash
|
||||
python dispatcher/tests/validate_arch_specs_parity.py --arch gfx950 --receipt 0
|
||||
# PASS: 33,541 kernels across all 9 families
|
||||
```
|
||||
|
||||
This confirms the dispatcher's self-contained enumeration exactly matches CK Tile's upstream codegen.
|
||||
|
||||
## Example: Single-Shape All-Kernel Benchmark
|
||||
|
||||
Run every compiled fwd fp16 h128 kernel against one shape:
|
||||
|
||||
```bash
|
||||
python fmha_full_benchmark.py \
|
||||
--category smoke --variant fwd --workers 256 \
|
||||
--filter "c.hdim_q == 128 and c.hdim_v == 128 and c.data_type == 'fp16'" \
|
||||
--csv results.csv
|
||||
```
|
||||
788
tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml
Normal file
788
tile_engine/ops/fmha/ck_fmha_testing_matrix.yaml
Normal file
@@ -0,0 +1,788 @@
|
||||
test_categories:
|
||||
Smoke:
|
||||
description: "Pre-submit sanity checks. Fast execution, covering basic functionality and edge cases."
|
||||
test_patterns:
|
||||
- "*/Smoke.*"
|
||||
labels: ["Smoke"]
|
||||
|
||||
Full:
|
||||
description: "Post-submit validation. Comprehensive coverage of modern LLM architectures and CK operational constraints."
|
||||
test_patterns:
|
||||
- "*/Smoke.*"
|
||||
- "*/Full.*"
|
||||
labels: ["Full"]
|
||||
|
||||
Nightly:
|
||||
description: "Nightly exhaustive coverage. Sweeps all combinations of precision, layout, masking, and padding."
|
||||
test_patterns:
|
||||
- "*"
|
||||
labels: ["Nightly"]
|
||||
|
||||
execution_settings:
|
||||
default_timeout: 60
|
||||
category_timeouts:
|
||||
Smoke: 60 # 1 min per test
|
||||
Full: 300 # 5 min per test
|
||||
Nightly: 600 # 10 min per test
|
||||
|
||||
# =============================================================================
|
||||
# Forward Pass (Prefill) & Stochastic Execution (Dropout)
|
||||
# =============================================================================
|
||||
forward_tests:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smoke Tests (Fast, representative subset)
|
||||
# ---------------------------------------------------------------------------
|
||||
smoke:
|
||||
- name: "GQA_4to1_Prefill_Basic"
|
||||
description: "Baseline GQA prefill; primary optimization target."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [2048]
|
||||
seqlen_k: [2048]
|
||||
nhead_q: [32]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false, true]
|
||||
|
||||
- name: "Small_GQA_7to1_SubWarp"
|
||||
description: "Sub-warp vectorized loads; low LDS utilization bounds."
|
||||
batch: [1]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [14]
|
||||
nhead_k: [2]
|
||||
hdim_q: [64]
|
||||
hdim_v: [64]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "MHA_H96_Irregular_Dim"
|
||||
description: "Non-power-of-2 hdim; forces complex padding/striding in LDS."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [32]
|
||||
nhead_k: [32]
|
||||
hdim_q: [96]
|
||||
hdim_v: [96]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
# CK smoke test edge cases (from example/ck_tile/01_fmha/script/smoke_test_fwd.sh)
|
||||
- name: "CK_Asymmetric_Hdim_Small"
|
||||
description: "Asymmetric hdim_q != hdim_v; tests vectorized load widths."
|
||||
batch: [2]
|
||||
seqlen_q: [55]
|
||||
seqlen_k: [256]
|
||||
nhead_q: [2]
|
||||
nhead_k: [1]
|
||||
hdim_q: [16]
|
||||
hdim_v: [32, 64, 128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "CK_Tiny_Sequences"
|
||||
description: "Edge cases: sq=1, sq=3, very short sequences."
|
||||
batch: [1, 2]
|
||||
seqlen_q: [1, 3, 33]
|
||||
seqlen_k: [10, 99, 33]
|
||||
nhead_q: [2]
|
||||
nhead_k: [1]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "CK_Asymmetric_Seqlen"
|
||||
description: "Asymmetric seqlen_q != seqlen_k from CK smoke tests."
|
||||
batch: [1, 2]
|
||||
seqlen_q: [100, 99, 1024]
|
||||
seqlen_k: [51, 256, 256]
|
||||
nhead_q: [3]
|
||||
nhead_k: [3]
|
||||
hdim_q: [64, 128]
|
||||
hdim_v: [64, 128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
# Hdim sweep covering all supported (hdim_q, hdim_v) pairs.
|
||||
# YAML cartesian product creates some orphan combos (hdim_q != hdim_v pairs
|
||||
# without kernels). The benchmark silently skips these. Use --validate to list them.
|
||||
# Supported pairs: h32, h64, h80x96, h96, h96x128, h128, h160, h192x128, h192, h256
|
||||
- name: "CK_All_Hdim_Sweep"
|
||||
description: "Sweep all supported hdim combos. Orphan pairs are skipped at runtime."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [8]
|
||||
nhead_k: [4]
|
||||
hdim_q: [32, 64, 80, 96, 128, 160, 192, 256]
|
||||
hdim_v: [32, 64, 96, 128, 160, 192, 256]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "CK_FP8_Basic"
|
||||
description: "FP8 basic forward test."
|
||||
batch: [1, 2]
|
||||
seqlen_q: [128]
|
||||
seqlen_k: [128]
|
||||
nhead_q: [1]
|
||||
nhead_k: [1]
|
||||
hdim_q: [64, 128, 192, 256]
|
||||
hdim_v: [64, 128, 128, 256]
|
||||
dtype: ["fp8bf16", "fp8fp32"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
# Production model configs (from aiter model_shapes.json)
|
||||
- name: "GQA_16to1_Large"
|
||||
description: "16:1 GQA ratio (70B-class models)."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [2048]
|
||||
seqlen_k: [2048]
|
||||
nhead_q: [64]
|
||||
nhead_k: [4]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "MQA_128to8_Decode"
|
||||
description: "405B-class decode: 128 Q heads, 8 KV heads, single token query."
|
||||
batch: [1, 8, 64]
|
||||
seqlen_q: [1]
|
||||
seqlen_k: [1024, 4096]
|
||||
nhead_q: [128]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "MLA_Sparse_Decode"
|
||||
description: "Multi-latent attention decode (R1-class): asymmetric h192x128."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [1]
|
||||
seqlen_k: [1024, 4096]
|
||||
nhead_q: [128]
|
||||
nhead_k: [128]
|
||||
hdim_q: [192]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Vision_Transformer_Shapes"
|
||||
description: "Vision-text hybrid (Maverick-class): h88 and h128 mixed."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [256, 1024]
|
||||
seqlen_k: [256, 1024]
|
||||
nhead_q: [16, 40]
|
||||
nhead_k: [8, 16]
|
||||
hdim_q: [88, 128]
|
||||
hdim_v: [88, 128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "FP8_Varlen_Realistic"
|
||||
description: "FP8 with realistic GQA and variable lengths (from aiter tests)."
|
||||
batch: [1, 8]
|
||||
seqlen_q: [113, 256, 1024]
|
||||
seqlen_k: [203, 512, 1024]
|
||||
nhead_q: [8, 32, 40]
|
||||
nhead_k: [1, 8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp8bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Extreme_GQA_Ratios"
|
||||
description: "Extreme GQA: 5:1, 10:1, 24:4, 48:8 from aiter test suite."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [5, 10, 24, 48]
|
||||
nhead_k: [1, 1, 4, 8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Paged_Decode_Shapes"
|
||||
description: "Paged attention decode patterns: single-token Q, long KV context."
|
||||
batch: [4, 80, 128]
|
||||
seqlen_q: [1, 4]
|
||||
seqlen_k: [512, 4096]
|
||||
nhead_q: [8, 16, 64]
|
||||
nhead_k: [1, 4]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Prefill_Odd_Lengths"
|
||||
description: "Prefill with non-standard seq lengths from aiter test suite."
|
||||
batch: [2]
|
||||
seqlen_q: [113, 339, 799, 1023, 3131]
|
||||
seqlen_k: [203, 339, 799, 1024, 3131]
|
||||
nhead_q: [32]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full Tests (Modern LLM Architectures & CK Constraints)
|
||||
# ---------------------------------------------------------------------------
|
||||
full:
|
||||
- name: "MHA_H256_High_LDS_Pressure"
|
||||
description: "High LDS pressure; tests block partitioner limits with hdim=256."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [4096]
|
||||
seqlen_k: [4096]
|
||||
nhead_q: [8]
|
||||
nhead_k: [4]
|
||||
hdim_q: [256]
|
||||
hdim_v: [256]
|
||||
dtype: ["bf16"]
|
||||
layout: ["BHSD", "BSHD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [true]
|
||||
|
||||
- name: "MQA_64to1_Broadcast"
|
||||
description: "Pure MQA; tests extreme KV to Q broadcast logic (64:1)."
|
||||
batch: [2]
|
||||
seqlen_q: [4096]
|
||||
seqlen_k: [4096]
|
||||
nhead_q: [64]
|
||||
nhead_k: [1]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "GQA_6to1_Irregular"
|
||||
description: "Irregular 6:1 GQA ratio; tests tile distribution."
|
||||
batch: [2]
|
||||
seqlen_q: [4096]
|
||||
seqlen_k: [4096]
|
||||
nhead_q: [48]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "MLA_H128xH576_Asymmetric"
|
||||
description: "Multi-latent attention fusion; asymmetric Q/KV (128 vs 576)."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [4096]
|
||||
seqlen_k: [4096]
|
||||
nhead_q: [128]
|
||||
nhead_k: [128]
|
||||
hdim_q: [128]
|
||||
hdim_v: [576]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [true]
|
||||
|
||||
- name: "Asymmetric_Head_Dims_192_128"
|
||||
description: "Test asymmetric head dimensions (192x128)."
|
||||
batch: [2]
|
||||
seqlen_q: [2048]
|
||||
seqlen_k: [2048]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [192]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD", "BSHD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Asymmetric_Head_Dims_128_192"
|
||||
description: "Test asymmetric head dimensions (128x192)."
|
||||
batch: [2]
|
||||
seqlen_q: [2048]
|
||||
seqlen_k: [2048]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [192]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Diverse_Head_Dims_Sweep"
|
||||
description: "Sweep across various head dimensions to ensure broad coverage."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [48, 64, 72, 96, 128, 160, 256]
|
||||
hdim_v: [48, 64, 72, 96, 128, 160, 256]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Stochastic_Execution_Dropout_Sweep"
|
||||
description: "PRNG state synchronization and warp alignment with stochastic masking across dims."
|
||||
batch: [4]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [16]
|
||||
nhead_k: [8]
|
||||
hdim_q: [48, 64, 72, 96, 128, 160, 256]
|
||||
hdim_v: [48, 64, 72, 96, 128, 160, 256]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.1, 0.2]
|
||||
lse: [false, true]
|
||||
|
||||
- name: "Padding_Boundary_Stress_Odd_Lengths"
|
||||
description: "Test sequences that are not perfect multiples of the tile size to validate padding logic."
|
||||
batch: [2]
|
||||
seqlen_q: [259, 500, 987, 1023]
|
||||
seqlen_k: [259, 500, 987, 1023]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Bias_Variants_Sweep"
|
||||
description: "Test elementwise and alibi bias across different sequence lengths and batch sizes."
|
||||
batch: [1, 4]
|
||||
seqlen_q: [512, 1024]
|
||||
seqlen_k: [512, 1024]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [64, 128]
|
||||
hdim_v: [64, 128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["elementwise", "alibi"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Extreme_Batch_Size_Stress"
|
||||
description: "Test very large batch sizes to stress grid launch dimensions and scheduling."
|
||||
batch: [64, 128, 256]
|
||||
seqlen_q: [128]
|
||||
seqlen_k: [128]
|
||||
nhead_q: [8]
|
||||
nhead_k: [8]
|
||||
hdim_q: [64]
|
||||
hdim_v: [64]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
- name: "Long_Sequence_Stress"
|
||||
description: "Test very long sequences (approaching split-KV territory but forced dense)."
|
||||
batch: [1]
|
||||
seqlen_q: [8192, 16384]
|
||||
seqlen_k: [8192, 16384]
|
||||
nhead_q: [16]
|
||||
nhead_k: [4]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [true]
|
||||
|
||||
- name: "Cross_Attention_Shapes"
|
||||
description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k."
|
||||
batch: [2]
|
||||
seqlen_q: [1, 32, 128]
|
||||
seqlen_k: [1024, 4096]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
|
||||
- name: "CK_Benchmark_Standard"
|
||||
description: "Standard CK benchmark sweep (from benchmark_fwd.sh)."
|
||||
batch: [32, 16, 8, 4, 2, 1]
|
||||
seqlen_q: [512, 1024, 2048, 4096, 8192, 16384]
|
||||
seqlen_k: [512, 1024, 2048, 4096, 8192, 16384]
|
||||
nhead_q: [32, 16, 8]
|
||||
nhead_k: [32, 16, 8]
|
||||
hdim_q: [64, 128, 256]
|
||||
hdim_v: [64, 128, 256]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
|
||||
- name: "CK_Benchmark_V3_Large"
|
||||
description: "V3 pipeline benchmark with very long sequences (from benchmark_fwd_v3.sh)."
|
||||
batch: [1]
|
||||
seqlen_q: [16384, 37200, 65536]
|
||||
seqlen_k: [16384, 37200, 65536]
|
||||
nhead_q: [16, 40, 64]
|
||||
nhead_k: [1, 16, 40, 64]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
lse: [false]
|
||||
|
||||
# =============================================================================
|
||||
# Backward Pass (Gradient Computation)
|
||||
# =============================================================================
|
||||
backward_tests:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smoke Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
smoke:
|
||||
- name: "Bwd_Basic_No_Features"
|
||||
description: "Basic backward pass without optional features."
|
||||
batch: [1, 2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_GQA_Smoke"
|
||||
description: "Backward GQA smoke test (4:1 and 8:1 ratios)."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [32]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Hdim_Sweep_Smoke"
|
||||
description: "Backward across key head dimensions."
|
||||
batch: [2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [8]
|
||||
nhead_k: [8]
|
||||
hdim_q: [64, 96, 128, 256]
|
||||
hdim_v: [64, 96, 128, 256]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_With_Mask_Dropout"
|
||||
description: "Backward with causal mask and dropout."
|
||||
batch: [2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [64, 128]
|
||||
hdim_v: [64, 128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.1]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Asymmetric_Hdim_Smoke"
|
||||
description: "Backward with asymmetric head dimensions."
|
||||
batch: [2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [192]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
full:
|
||||
- name: "Bwd_GQA_Support"
|
||||
description: "Backward pass with Grouped Query Attention."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [32, 64]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_High_Capacity_H256"
|
||||
description: "Backward pass with hdim=256; high LDS pressure."
|
||||
batch: [1]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [8]
|
||||
nhead_k: [4]
|
||||
hdim_q: [256]
|
||||
hdim_v: [256]
|
||||
dtype: ["bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Irregular_H96"
|
||||
description: "Backward pass with non-power-of-2 hdim."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [32]
|
||||
nhead_k: [32]
|
||||
hdim_q: [96]
|
||||
hdim_v: [96]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_All_Features_Enabled"
|
||||
description: "Backward pass with bias gradients, dropout, and deterministic accumulation."
|
||||
batch: [2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [48, 64, 72, 96, 128, 160, 256]
|
||||
hdim_v: [48, 64, 72, 96, 128, 160, 256]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["elementwise", "alibi"]
|
||||
dropout: [0.1]
|
||||
has_dbias: [true]
|
||||
is_deterministic: [true]
|
||||
|
||||
- name: "Bwd_Padding_Boundary_Stress"
|
||||
description: "Test backward pass with sequences that are not perfect multiples of the tile size."
|
||||
batch: [1]
|
||||
seqlen_q: [259, 500, 1023]
|
||||
seqlen_k: [259, 500, 1023]
|
||||
nhead_q: [8]
|
||||
nhead_k: [8]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask", "top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Asymmetric_Head_Dims_192_128"
|
||||
description: "Test backward pass with asymmetric head dimensions (192x128)."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [192]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["top_left"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Asymmetric_Head_Dims_128_192"
|
||||
description: "Test backward pass with asymmetric head dimensions (128x192)."
|
||||
batch: [2]
|
||||
seqlen_q: [1024]
|
||||
seqlen_k: [1024]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [192]
|
||||
dtype: ["fp16", "bf16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Diverse_Head_Dims_Sweep"
|
||||
description: "Sweep backward pass across various head dimensions."
|
||||
batch: [2]
|
||||
seqlen_q: [512]
|
||||
seqlen_k: [512]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [48, 64, 72, 96, 128, 160, 256]
|
||||
hdim_v: [48, 64, 72, 96, 128, 160, 256]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
|
||||
- name: "Bwd_Cross_Attention_Shapes"
|
||||
description: "Test shapes typical of cross-attention where seqlen_q != seqlen_k in backward."
|
||||
batch: [2]
|
||||
seqlen_q: [1, 32, 128]
|
||||
seqlen_k: [1024, 4096]
|
||||
nhead_q: [16]
|
||||
nhead_k: [16]
|
||||
hdim_q: [128]
|
||||
hdim_v: [128]
|
||||
dtype: ["fp16"]
|
||||
layout: ["BHSD"]
|
||||
mask: ["no_mask"]
|
||||
bias: ["none"]
|
||||
dropout: [0.0]
|
||||
has_dbias: [false]
|
||||
is_deterministic: [false]
|
||||
6
tile_engine/ops/fmha/configs/appendkv.json
Normal file
6
tile_engine/ops/fmha/configs/appendkv.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "appendkv",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16", "fp8"]}
|
||||
}
|
||||
}
|
||||
6
tile_engine/ops/fmha/configs/batch_prefill.json
Normal file
6
tile_engine/ops/fmha/configs/batch_prefill.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "batch_prefill",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16", "fp8bf16"]}
|
||||
}
|
||||
}
|
||||
6
tile_engine/ops/fmha/configs/bwd.json
Normal file
6
tile_engine/ops/fmha/configs/bwd.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "bwd",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16"]}
|
||||
}
|
||||
}
|
||||
9
tile_engine/ops/fmha/configs/fwd.json
Normal file
9
tile_engine/ops/fmha/configs/fwd.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"variant": "fwd",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16"]},
|
||||
"pipeline": {"values": ["qr", "qr_async"]},
|
||||
"mask": {"values": ["no", "top_left"]},
|
||||
"bias": {"values": ["no"]}
|
||||
}
|
||||
}
|
||||
14
tile_engine/ops/fmha/configs/fwd_ci.json
Normal file
14
tile_engine/ops/fmha/configs/fwd_ci.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"variant": "fwd",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16"]},
|
||||
"pipeline": {"values": ["qr_async"]},
|
||||
"mask": {"values": ["no"]},
|
||||
"bias": {"values": ["no"]},
|
||||
"mode": {"values": ["batch"]},
|
||||
"lse": {"values": [false]},
|
||||
"dropout": {"values": [false]},
|
||||
"logits": {"values": [false]},
|
||||
"sink": {"values": [false]}
|
||||
}
|
||||
}
|
||||
6
tile_engine/ops/fmha/configs/pagedkv.json
Normal file
6
tile_engine/ops/fmha/configs/pagedkv.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "pagedkv",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16", "fp8"]}
|
||||
}
|
||||
}
|
||||
6
tile_engine/ops/fmha/configs/receipt0_fwd.json
Normal file
6
tile_engine/ops/fmha/configs/receipt0_fwd.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "fwd",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16", "fp8bf16", "fp8fp32"]}
|
||||
}
|
||||
}
|
||||
6
tile_engine/ops/fmha/configs/splitkv.json
Normal file
6
tile_engine/ops/fmha/configs/splitkv.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"variant": "splitkv",
|
||||
"trait_config": {
|
||||
"data_type": {"values": ["fp16", "bf16", "fp8"]}
|
||||
}
|
||||
}
|
||||
14
tile_engine/ops/fmha/filters/h128_no_dropout.py
Normal file
14
tile_engine/ops/fmha/filters/h128_no_dropout.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Sample filter: only h128 kernels without dropout.
|
||||
|
||||
Usage:
|
||||
python fmha_benchmark.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py
|
||||
python fmha_instance_builder.py configs/receipt0_fwd.json --filter-file filters/h128_no_dropout.py --count-only
|
||||
"""
|
||||
|
||||
|
||||
def filter_config(c) -> bool:
|
||||
"""Keep only h128 kernels without dropout."""
|
||||
return c.hdim_q == 128 and not c.dropout
|
||||
939
tile_engine/ops/fmha/fmha_benchmark.py
Normal file
939
tile_engine/ops/fmha/fmha_benchmark.py
Normal file
@@ -0,0 +1,939 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
FMHA tile engine benchmark runner.
|
||||
|
||||
Uses the dispatcher's setup_multiple_fmha_dispatchers() for pipelined JIT
|
||||
compilation, then runs GPU benchmarks and reports results.
|
||||
|
||||
Usage:
|
||||
python fmha_benchmark.py configs/fwd.json
|
||||
python fmha_benchmark.py configs/fwd.json --workers 256 --build-dir /tmp/fmha_build
|
||||
python fmha_benchmark.py configs/fwd.json --problems "2,8,1024,128" --verify
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
_DISPATCHER_ROOT = Path(__file__).resolve().parents[3] / "dispatcher"
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen"))
|
||||
|
||||
from fmha_utils import ( # noqa: E402
|
||||
FmhaProblem,
|
||||
FmhaRunner,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_multiple_fmha_dispatchers,
|
||||
)
|
||||
|
||||
from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402
|
||||
|
||||
# Reusable multi-GPU job dispatcher (op-agnostic)
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "common"))
|
||||
from parallel_runner import run_parallel_on_gpus # noqa: E402
|
||||
|
||||
|
||||
def _compute_result(
|
||||
config,
|
||||
prob,
|
||||
time_ms,
|
||||
output,
|
||||
ref,
|
||||
is_causal,
|
||||
ns,
|
||||
api_family,
|
||||
dtype_tol,
|
||||
gpu_id=None,
|
||||
):
|
||||
"""Compute tflops, max_err, status and build result dict + display line.
|
||||
|
||||
Returns (result_dict, display_line) or None if time_ms is None/0.
|
||||
"""
|
||||
tflops = prob.num_ops / (time_ms * 1e-3) / 1e12 if time_ms > 0 else 0
|
||||
if is_causal and time_ms > 0:
|
||||
sq, sk = prob.seqlen_q, prob.seqlen_k
|
||||
causal_ratio = (min(sq, sk) + 1) / (2.0 * sk)
|
||||
tflops = prob.num_ops * causal_ratio / (time_ms * 1e-3) / 1e12
|
||||
|
||||
max_err = 0.0
|
||||
status = "OK"
|
||||
if ref is not None and output is not None:
|
||||
max_err = float(np.abs(output.astype(np.float32) - ref).max())
|
||||
atol, rtol = dtype_tol
|
||||
tol = atol + rtol * np.abs(ref).max()
|
||||
status = "PASS" if max_err < tol else "FAIL"
|
||||
|
||||
splits_tag = f" [ns={ns}]" if api_family == "splitkv" else ""
|
||||
display_name = f"{config.name}{splits_tag}"
|
||||
gpu_tag = f" [GPU{gpu_id}]" if gpu_id is not None else ""
|
||||
display_line = (
|
||||
f" {display_name:<105} {time_ms:>10.3f}"
|
||||
f" {tflops:>10.2f} {max_err:>10.2e} {status:>6}{gpu_tag}"
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"kernel": config.name,
|
||||
"dtype": config.data_type,
|
||||
"hdim_q": config.hdim_q,
|
||||
"hdim_v": config.hdim_v,
|
||||
"pipeline": config.pipeline,
|
||||
"mode": config.mode,
|
||||
"mask": config.mask,
|
||||
"bias": config.bias,
|
||||
"tile_m0": config.tile_m0,
|
||||
"tile_n0": config.tile_n0,
|
||||
"tile_k0": config.tile_k0,
|
||||
"tile_n1": config.tile_n1,
|
||||
"tile_k1": config.tile_k1,
|
||||
"tile_k0max": config.tile_k0max,
|
||||
"warp_m0": config.warp_m0,
|
||||
"warp_n0": config.warp_n0,
|
||||
"warp_k0": config.warp_k0,
|
||||
"block_per_cu": config.block_per_cu,
|
||||
"num_splits": ns if api_family == "splitkv" else None,
|
||||
"problem": {
|
||||
"batch": prob.batch,
|
||||
"nhead_q": prob.nhead_q,
|
||||
"nhead_k": prob.nhead_k,
|
||||
"seqlen_q": prob.seqlen_q,
|
||||
"seqlen_k": prob.seqlen_k,
|
||||
"hdim_q": prob.hdim_q,
|
||||
"hdim_v": prob.hdim_v,
|
||||
},
|
||||
"latency_ms": time_ms,
|
||||
"tflops": tflops,
|
||||
"max_err": max_err,
|
||||
"status": status,
|
||||
}
|
||||
return result_dict, display_line
|
||||
|
||||
|
||||
def _run_kernel_isolated(
|
||||
lib_path, arch, prob, run_kwargs, data_dir, gpu_id=0, timeout=120
|
||||
):
|
||||
"""Run a single kernel in a subprocess. Returns (time_ms, output_path) or (None, error_msg).
|
||||
|
||||
Survives GPU faults — if the subprocess crashes, returns an error instead of killing main.
|
||||
"""
|
||||
import json as _json
|
||||
import subprocess as sp
|
||||
|
||||
# Write a small runner script that the subprocess will execute.
|
||||
# Use json.dumps for string values to safely escape quotes/backslashes in paths.
|
||||
_lib = _json.dumps(str(lib_path))
|
||||
_arch = _json.dumps(str(arch))
|
||||
_pydir = _json.dumps(str(_DISPATCHER_ROOT / "python"))
|
||||
_ddir = _json.dumps(str(data_dir))
|
||||
script = f'''
|
||||
import sys, os, numpy as np
|
||||
os.environ["HIP_VISIBLE_DEVICES"] = "{gpu_id}"
|
||||
sys.path.insert(0, {_pydir})
|
||||
from fmha_utils import FmhaRunner, FmhaProblem
|
||||
|
||||
runner = FmhaRunner.from_library({_lib}, {_arch})
|
||||
_d = {_ddir}
|
||||
Q = np.load(os.path.join(_d, "Q.npy"))
|
||||
K = np.load(os.path.join(_d, "K.npy"))
|
||||
V = np.load(os.path.join(_d, "V.npy"))
|
||||
prob = FmhaProblem(batch={prob.batch}, nhead_q={prob.nhead_q}, nhead_k={prob.nhead_k},
|
||||
seqlen_q={prob.seqlen_q}, seqlen_k={prob.seqlen_k},
|
||||
hdim_q={prob.hdim_q}, hdim_v={prob.hdim_v})
|
||||
result = runner.run(Q, K, V, prob, **{run_kwargs!r})
|
||||
if result.success:
|
||||
np.save(os.path.join(_d, "O.npy"), result.output)
|
||||
print(f"TIME={{result.time_ms}}")
|
||||
else:
|
||||
print("FAIL")
|
||||
runner.cleanup()
|
||||
'''
|
||||
script_path = os.path.join(data_dir, "run_kernel.py")
|
||||
with open(script_path, "w") as f:
|
||||
f.write(script)
|
||||
|
||||
try:
|
||||
r = sp.run(
|
||||
[sys.executable, script_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env={**os.environ, "HIP_VISIBLE_DEVICES": str(gpu_id)},
|
||||
)
|
||||
if r.returncode != 0:
|
||||
err = r.stderr[-200:] if r.stderr else f"exit code {r.returncode}"
|
||||
return None, None, f"CRASH: {err.strip()}"
|
||||
# Parse time from stdout
|
||||
for line in r.stdout.strip().split("\n"):
|
||||
if line.startswith("TIME="):
|
||||
time_ms = float(line[5:])
|
||||
out_path = os.path.join(data_dir, "O.npy")
|
||||
output = np.load(out_path) if os.path.exists(out_path) else None
|
||||
return time_ms, output, None
|
||||
return None, None, "No TIME output"
|
||||
except sp.TimeoutExpired:
|
||||
return None, None, "TIMEOUT"
|
||||
except Exception as e:
|
||||
return None, None, str(e)
|
||||
|
||||
|
||||
def parse_problems(spec: str) -> List[FmhaProblem]:
|
||||
"""Parse problem specs: 'batch,nhead,seqlen,hdim;...'"""
|
||||
problems = []
|
||||
for part in spec.split(";"):
|
||||
vals = [int(x) for x in part.split(",")]
|
||||
if len(vals) == 4:
|
||||
b, h, s, d = vals
|
||||
problems.append(
|
||||
FmhaProblem(
|
||||
batch=b,
|
||||
nhead_q=h,
|
||||
nhead_k=h,
|
||||
seqlen_q=s,
|
||||
seqlen_k=s,
|
||||
hdim_q=d,
|
||||
hdim_v=d,
|
||||
)
|
||||
)
|
||||
elif len(vals) == 6:
|
||||
b, hq, hk, sq, sk, d = vals
|
||||
problems.append(
|
||||
FmhaProblem(
|
||||
batch=b,
|
||||
nhead_q=hq,
|
||||
nhead_k=hk,
|
||||
seqlen_q=sq,
|
||||
seqlen_k=sk,
|
||||
hdim_q=d,
|
||||
hdim_v=d,
|
||||
)
|
||||
)
|
||||
return problems
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="FMHA Tile Engine Benchmark")
|
||||
parser.add_argument(
|
||||
"configs", nargs="*", help="Sweep config JSON(s) (optional for exhaustive)"
|
||||
)
|
||||
parser.add_argument("--arch", default=detect_gpu_arch())
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=os.cpu_count() or 8, help="Parallel JIT workers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--problems",
|
||||
default="2,8,1024,128",
|
||||
help="Problem sizes: batch,nhead,seqlen,hdim",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-verify", action="store_true", help="Skip CPU reference verification"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best", action="store_true", help="Show best kernel per problem"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CSV output path (default: <build-dir>/results.csv). Use --no-csv to disable.",
|
||||
)
|
||||
parser.add_argument("--no-csv", action="store_true", help="Disable CSV output")
|
||||
parser.add_argument("--json", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--log",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to detailed log file (compilation status, failures, timings)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-dir",
|
||||
type=str,
|
||||
default=str(Path(__file__).resolve().parent / "build"),
|
||||
help="JIT build output directory",
|
||||
)
|
||||
parser.add_argument("--clean", action="store_true")
|
||||
parser.add_argument("--compile-only", action="store_true")
|
||||
parser.add_argument(
|
||||
"--filter",
|
||||
dest="filter_expr",
|
||||
default="",
|
||||
help='Python expr per config, e.g. "c.hdim_q == 128"',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter-file", default="", help="Path to .py with filter_config(c) -> bool"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiles",
|
||||
choices=["rules", "exhaustive"],
|
||||
default="rules",
|
||||
help="Tile enumeration mode: 'rules' (default) uses constraint-based generation; "
|
||||
"'exhaustive' brute-forces ALL compilable tiles (like the oracle)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
default="1,2,4,8",
|
||||
help="Comma-separated num_splits values to sweep for splitkv (default: 1,2,4,8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--isolate",
|
||||
action="store_true",
|
||||
help="Run each kernel in a subprocess to survive GPU faults (slower but fault-tolerant)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpus",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated GPU IDs to use for parallel benchmarking (e.g. '0,1,2,3'). "
|
||||
"Implies --isolate. Each GPU runs one kernel at a time.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --gpus implies --isolate
|
||||
if args.gpus:
|
||||
args.isolate = True
|
||||
gpu_ids = [int(x) for x in args.gpus.split(",")] if args.gpus else [0]
|
||||
|
||||
problems = parse_problems(args.problems)
|
||||
num_splits_list = [int(x) for x in args.num_splits.split(",")]
|
||||
build_dir = Path(args.build_dir).resolve()
|
||||
|
||||
if args.clean and build_dir.exists():
|
||||
print(f" Cleaning {build_dir} ...")
|
||||
shutil.rmtree(build_dir)
|
||||
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Phase 0: Expand configs
|
||||
all_configs = []
|
||||
restrict_hdims = sorted({(p.hdim_q, p.hdim_v) for p in problems})
|
||||
if args.tiles == "exhaustive":
|
||||
# Exhaustive mode: all tiles (no constraint filter) × full feature cross-product.
|
||||
# JSON config is optional — if provided, its trait_config scopes the sweep.
|
||||
cfg_path = args.configs[0] if args.configs else None
|
||||
all_configs = expand_sweep(
|
||||
cfg_path,
|
||||
args.arch,
|
||||
0,
|
||||
mode="exhaustive",
|
||||
restrict_hdims=restrict_hdims,
|
||||
)
|
||||
print(
|
||||
f" Exhaustive: {len(all_configs)} total combos (all tiles × all features)"
|
||||
)
|
||||
else:
|
||||
if not args.configs:
|
||||
parser.error(
|
||||
"Config JSON(s) required for rules mode. Use --tiles exhaustive to run without."
|
||||
)
|
||||
for cfg_path in args.configs:
|
||||
configs = expand_sweep(
|
||||
cfg_path,
|
||||
args.arch,
|
||||
0,
|
||||
mode="rules",
|
||||
restrict_hdims=restrict_hdims,
|
||||
)
|
||||
all_configs.extend(configs)
|
||||
print(f" {cfg_path}: {len(configs)} kernel configs")
|
||||
|
||||
if args.filter_expr or args.filter_file:
|
||||
before = len(all_configs)
|
||||
all_configs = apply_filter(all_configs, args.filter_expr, args.filter_file)
|
||||
print(f" Filter: {before} -> {len(all_configs)} configs")
|
||||
|
||||
# Remove standalone combine configs -- they are auto-paired during JIT
|
||||
all_configs = [c for c in all_configs if c.family != "fwd_splitkv_combine"]
|
||||
|
||||
print(f"\n{'=' * 70}")
|
||||
print("FMHA Tile Engine Benchmark")
|
||||
print(f"{'=' * 70}")
|
||||
print(f" Arch: {args.arch}")
|
||||
print(f" Kernels: {len(all_configs)}")
|
||||
print(f" Problems: {len(problems)}")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Build: {build_dir}")
|
||||
|
||||
# Phase 1: Pipelined JIT via the dispatcher
|
||||
print(
|
||||
f"\n--- Phase 1: JIT compile ({len(all_configs)} kernels,"
|
||||
f" {args.workers} workers) ---"
|
||||
)
|
||||
jit_t0 = time.perf_counter()
|
||||
|
||||
def _progress(stage, done, total):
|
||||
elapsed = time.perf_counter() - jit_t0
|
||||
pct = done * 100 // total
|
||||
print(
|
||||
f"\r [{stage}] {done}/{total} ({pct}%) - {elapsed:.0f}s",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
if done == total:
|
||||
print()
|
||||
|
||||
setups = setup_multiple_fmha_dispatchers(
|
||||
all_configs,
|
||||
output_dir=build_dir,
|
||||
verbose=True,
|
||||
max_workers=args.workers,
|
||||
progress_callback=_progress,
|
||||
)
|
||||
|
||||
jit_time = time.perf_counter() - jit_t0
|
||||
built = sum(1 for s in setups if s.success)
|
||||
failed = len(all_configs) - built
|
||||
print(f"\n Built {built}/{len(all_configs)} in {jit_time:.0f}s ({failed} failed)")
|
||||
|
||||
# Load runners for successfully compiled kernels
|
||||
for setup in setups:
|
||||
if setup.success and setup.library_path and setup.runner is None:
|
||||
try:
|
||||
setup.runner = FmhaRunner.from_library(setup.library_path, args.arch)
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to load runner: {e}")
|
||||
setup.success = False
|
||||
|
||||
if args.compile_only:
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" Compile-only mode. {built}/{len(all_configs)} kernels compiled.")
|
||||
if failed > 0:
|
||||
print("\n Failed kernels:")
|
||||
for cfg, s in zip(all_configs, setups):
|
||||
if not s.success:
|
||||
err = (s.error or "unknown")[:80]
|
||||
print(f" {cfg.name}: {err}")
|
||||
if args.tiles == "exhaustive":
|
||||
# Oracle-style analysis: find tiles missed by rules vs compilable
|
||||
from fmha.instance_gen import validate_tile, FmhaTileConfig # noqa: E402
|
||||
|
||||
missed = []
|
||||
for cfg, s in zip(all_configs, setups):
|
||||
if s.success:
|
||||
tile = FmhaTileConfig(
|
||||
bm0=cfg.tile_m0,
|
||||
bn0=cfg.tile_n0,
|
||||
bk0=cfg.tile_k0,
|
||||
bn1=cfg.tile_n1,
|
||||
bk1=cfg.tile_k1,
|
||||
bk0max=cfg.tile_k0max,
|
||||
rm0=cfg.wave_m0,
|
||||
rn0=1,
|
||||
rk0=1,
|
||||
rm1=cfg.wave_m1,
|
||||
rn1=1,
|
||||
rk1=1,
|
||||
wm0=cfg.warp_m0,
|
||||
wn0=cfg.warp_n0,
|
||||
wk0=cfg.warp_k0,
|
||||
wm1=cfg.warp_m1,
|
||||
wn1=cfg.warp_n1,
|
||||
wk1=cfg.warp_k1,
|
||||
)
|
||||
if not validate_tile(
|
||||
tile,
|
||||
args.arch,
|
||||
cfg.data_type,
|
||||
cfg.hdim_q,
|
||||
cfg.hdim_v,
|
||||
cfg.pipeline,
|
||||
):
|
||||
missed.append(cfg)
|
||||
if missed:
|
||||
print(
|
||||
f"\n MISSED by rules ({len(missed)} tiles compile but rules reject):"
|
||||
)
|
||||
seen = set()
|
||||
for cfg in missed:
|
||||
key = (cfg.tile_m0, cfg.tile_n0, cfg.tile_k0)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
print(
|
||||
f" ({cfg.tile_m0:>3}, {cfg.tile_n0:>3}, {cfg.tile_k0:>3})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\n Rules are COMPLETE — all compilable tiles are generated by rules."
|
||||
)
|
||||
print(f"{'=' * 70}")
|
||||
return
|
||||
|
||||
# Phase 2: Benchmark
|
||||
print(f"\n--- Phase 2: Benchmark ({built} kernels x {len(problems)} problems) ---")
|
||||
|
||||
dtype_map = {
|
||||
"fp16": np.float16,
|
||||
"bf16": np.float32,
|
||||
"fp32": np.float32,
|
||||
"fp8": np.float16,
|
||||
"fp8bf16": np.float16,
|
||||
"fp8fp32": np.float16,
|
||||
"bf8": np.float16,
|
||||
"mxfp8": np.float16,
|
||||
"mxfp4": np.float16,
|
||||
}
|
||||
# Tolerance per dtype: (atol, rtol)
|
||||
_DTYPE_TOL = {
|
||||
"fp16": (1e-3, 1e-3),
|
||||
"bf16": (1e-2, 1e-2),
|
||||
"fp32": (1e-5, 1e-5),
|
||||
"fp8": (16.0, 0.0),
|
||||
"fp8bf16": (16.0, 0.0),
|
||||
"fp8fp32": (16.0, 0.0),
|
||||
"bf8": (16.0, 0.0),
|
||||
"mxfp8": (16.0, 0.0),
|
||||
"mxfp4": (32.0, 0.0),
|
||||
}
|
||||
np.random.seed(42)
|
||||
all_results = []
|
||||
bench_t0 = time.perf_counter()
|
||||
|
||||
for prob_idx, prob in enumerate(problems):
|
||||
first_dtype = all_configs[0].data_type if all_configs else "fp16"
|
||||
first_mask = all_configs[0].mask if all_configs else "no"
|
||||
np_dtype = dtype_map.get(first_dtype, np.float16)
|
||||
dtype_tol = _DTYPE_TOL.get(first_dtype, (1e-2, 1e-2))
|
||||
# Use uniform [0, 1] like CK example (default 'uf' mode) -- produces
|
||||
# peaked softmax distributions that actually test kernel correctness.
|
||||
# randn*0.1 makes softmax nearly uniform for large hdim, hiding bugs.
|
||||
Q = np.random.uniform(0, 1, prob.q_shape()).astype(np_dtype)
|
||||
K = np.random.uniform(0, 1, prob.k_shape()).astype(np_dtype)
|
||||
V = np.random.uniform(0, 1, prob.v_shape()).astype(np_dtype)
|
||||
|
||||
_MASK_INT = {"no": 0, "top_left": 1, "bottom_right": 2, "generic": 3}
|
||||
first_mask_int = _MASK_INT.get(first_mask, 0)
|
||||
|
||||
ref = None
|
||||
if not args.no_verify:
|
||||
# For bf16: truncate inputs to bf16 precision before computing reference,
|
||||
# so reference sees the SAME data the kernel sees (after bf16 encoding).
|
||||
if first_dtype == "bf16":
|
||||
from fmha_utils import _float32_to_bf16, _bf16_to_float32
|
||||
|
||||
Q_ref = _bf16_to_float32(_float32_to_bf16(Q.astype(np.float32)))
|
||||
K_ref = _bf16_to_float32(_float32_to_bf16(K.astype(np.float32)))
|
||||
V_ref = _bf16_to_float32(_float32_to_bf16(V.astype(np.float32)))
|
||||
else:
|
||||
Q_ref = Q.astype(np.float32)
|
||||
K_ref = K.astype(np.float32)
|
||||
V_ref = V.astype(np.float32)
|
||||
ref = cpu_attention_fwd(
|
||||
Q_ref,
|
||||
K_ref,
|
||||
V_ref,
|
||||
prob.scale,
|
||||
mask_type=first_mask_int,
|
||||
)
|
||||
|
||||
h_str = (
|
||||
f"H={prob.nhead_q}"
|
||||
if prob.nhead_q == prob.nhead_k
|
||||
else f"Hq={prob.nhead_q} Hk={prob.nhead_k}"
|
||||
)
|
||||
s_str = (
|
||||
f"S={prob.seqlen_q}"
|
||||
if prob.seqlen_q == prob.seqlen_k
|
||||
else f"Sq={prob.seqlen_q} Sk={prob.seqlen_k}"
|
||||
)
|
||||
prob_str = f"B={prob.batch} {h_str} {s_str} D={prob.hdim_q}"
|
||||
print(f"\n Problem [{prob_idx}]: {prob_str}")
|
||||
print(
|
||||
f" {'Kernel':<105} {'Time(ms)':>10} {'TFLOPS':>10}"
|
||||
f" {'MaxErr':>10} {'Status':>6}"
|
||||
)
|
||||
print(f" {'-' * 145}")
|
||||
|
||||
_BIAS_INT = {"no": 0, "bias": 1, "alibi": 2}
|
||||
|
||||
# Build list of (config, setup, run_kwargs, ns) jobs for benchmarking
|
||||
bench_jobs = []
|
||||
for config, setup in zip(all_configs, setups):
|
||||
if not setup.success:
|
||||
continue
|
||||
if not args.isolate and setup.runner is None:
|
||||
continue
|
||||
if config.hdim_q != prob.hdim_q or config.hdim_v != prob.hdim_v:
|
||||
continue
|
||||
|
||||
mask_int = _MASK_INT.get(config.mask, 0)
|
||||
is_causal = config.mask in ("top_left", "bottom_right")
|
||||
is_group = config.mode == "group"
|
||||
|
||||
_FAMILY_TO_API = {
|
||||
"fwd_splitkv": "splitkv",
|
||||
"fwd_pagedkv": "pagedkv",
|
||||
"fwd_appendkv": "appendkv",
|
||||
}
|
||||
api_family = _FAMILY_TO_API.get(config.family, config.family)
|
||||
splits_to_try = num_splits_list if api_family == "splitkv" else [0]
|
||||
|
||||
for ns in splits_to_try:
|
||||
run_kwargs = dict(
|
||||
mask_type=mask_int,
|
||||
bias_type=_BIAS_INT.get(config.bias, 0),
|
||||
has_lse=int(config.lse),
|
||||
has_dropout=int(config.dropout),
|
||||
has_logits=int(config.logits),
|
||||
has_sink=int(config.sink),
|
||||
data_type=config.data_type,
|
||||
is_group_mode=int(is_group),
|
||||
is_v_rowmajor=int(config.vlayout == "r"),
|
||||
api_family=api_family,
|
||||
window_left=-1,
|
||||
window_right=0 if is_causal else -1,
|
||||
)
|
||||
if api_family == "splitkv":
|
||||
run_kwargs["num_splits"] = ns
|
||||
bench_jobs.append(
|
||||
(config, setup, run_kwargs, ns, api_family, is_causal)
|
||||
)
|
||||
|
||||
if args.isolate and len(gpu_ids) > 1:
|
||||
# ---- Multi-GPU parallel isolated execution ----
|
||||
import tempfile
|
||||
|
||||
# Save input data once, shared by all subprocesses
|
||||
shared_data_dir = tempfile.mkdtemp(prefix="fmha_shared_")
|
||||
np.save(os.path.join(shared_data_dir, "Q.npy"), Q)
|
||||
np.save(os.path.join(shared_data_dir, "K.npy"), K)
|
||||
np.save(os.path.join(shared_data_dir, "V.npy"), V)
|
||||
|
||||
def _run_one(job, gpu_id):
|
||||
config, setup, run_kwargs, ns, api_family, is_causal = job
|
||||
# Per-job output dir (unique per subprocess)
|
||||
job_dir = tempfile.mkdtemp(prefix=f"fmha_gpu{gpu_id}_")
|
||||
# Symlink shared inputs instead of copying
|
||||
for fname in ("Q.npy", "K.npy", "V.npy"):
|
||||
os.symlink(
|
||||
os.path.join(shared_data_dir, fname),
|
||||
os.path.join(job_dir, fname),
|
||||
)
|
||||
time_ms, output, err = _run_kernel_isolated(
|
||||
setup.library_path, args.arch, prob, run_kwargs, job_dir, gpu_id
|
||||
)
|
||||
shutil.rmtree(job_dir, ignore_errors=True)
|
||||
return (config, time_ms, output, err, ns, api_family, is_causal, gpu_id)
|
||||
|
||||
print(f" Running {len(bench_jobs)} kernels across {len(gpu_ids)} GPUs ...")
|
||||
for _, result in run_parallel_on_gpus(bench_jobs, gpu_ids, _run_one):
|
||||
config, time_ms, output, err, ns, api_family, is_causal, gpu_id = result
|
||||
if err:
|
||||
splits_tag = f" [ns={ns}]" if api_family == "splitkv" else ""
|
||||
print(
|
||||
f" {config.name}{splits_tag:<105} {'---':>10} {'---':>10} {'---':>10} GPU{gpu_id} {err[:15]}"
|
||||
)
|
||||
continue
|
||||
|
||||
r, line = _compute_result(
|
||||
config,
|
||||
prob,
|
||||
time_ms,
|
||||
output,
|
||||
ref,
|
||||
is_causal,
|
||||
ns,
|
||||
api_family,
|
||||
dtype_tol,
|
||||
gpu_id,
|
||||
)
|
||||
print(line)
|
||||
all_results.append(r)
|
||||
|
||||
shutil.rmtree(shared_data_dir, ignore_errors=True)
|
||||
|
||||
else:
|
||||
# ---- Sequential execution (in-process or single-GPU isolated) ----
|
||||
for config, setup, run_kwargs, ns, api_family, is_causal in bench_jobs:
|
||||
time_ms = None
|
||||
output = None
|
||||
if args.isolate:
|
||||
import tempfile
|
||||
|
||||
data_dir = tempfile.mkdtemp(prefix="fmha_run_")
|
||||
np.save(os.path.join(data_dir, "Q.npy"), Q)
|
||||
np.save(os.path.join(data_dir, "K.npy"), K)
|
||||
np.save(os.path.join(data_dir, "V.npy"), V)
|
||||
time_ms, output, err = _run_kernel_isolated(
|
||||
setup.library_path,
|
||||
args.arch,
|
||||
prob,
|
||||
run_kwargs,
|
||||
data_dir,
|
||||
gpu_ids[0],
|
||||
)
|
||||
shutil.rmtree(data_dir, ignore_errors=True)
|
||||
if err:
|
||||
print(
|
||||
f" {config.name:<105} {'---':>10} {'---':>10} {'---':>10} {err[:20]:>6}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
result = setup.runner.run(Q, K, V, prob, **run_kwargs)
|
||||
if not result.success:
|
||||
continue
|
||||
time_ms = result.time_ms
|
||||
output = result.output
|
||||
|
||||
r, line = _compute_result(
|
||||
config,
|
||||
prob,
|
||||
time_ms,
|
||||
output,
|
||||
ref,
|
||||
is_causal,
|
||||
ns,
|
||||
api_family,
|
||||
dtype_tol,
|
||||
)
|
||||
print(line)
|
||||
all_results.append(r)
|
||||
|
||||
bench_time = time.perf_counter() - bench_t0
|
||||
|
||||
# Cleanup
|
||||
for setup in setups:
|
||||
if setup.success and setup.runner:
|
||||
try:
|
||||
setup.runner.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Report
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" JIT: {jit_time:.0f}s ({built} kernels)")
|
||||
print(f" Benchmark: {bench_time:.1f}s")
|
||||
print(f" Results: {len(all_results)} measurements")
|
||||
|
||||
if all_results:
|
||||
from collections import defaultdict
|
||||
|
||||
by_problem = defaultdict(list)
|
||||
for r in all_results:
|
||||
key = json.dumps(r["problem"], sort_keys=True)
|
||||
by_problem[key].append(r)
|
||||
|
||||
print("\n Best kernel per problem:")
|
||||
for key, results in by_problem.items():
|
||||
best = max(results, key=lambda x: x["tflops"])
|
||||
prob = json.loads(key)
|
||||
ns_tag = f" [ns={best['num_splits']}]" if best.get("num_splits") else ""
|
||||
h_str = (
|
||||
f"H={prob['nhead_q']}"
|
||||
if prob["nhead_q"] == prob["nhead_k"]
|
||||
else f"Hq={prob['nhead_q']} Hk={prob['nhead_k']}"
|
||||
)
|
||||
s_str = (
|
||||
f"S={prob['seqlen_q']}"
|
||||
if prob["seqlen_q"] == prob["seqlen_k"]
|
||||
else f"Sq={prob['seqlen_q']} Sk={prob['seqlen_k']}"
|
||||
)
|
||||
print(
|
||||
f" B={prob['batch']} {h_str}"
|
||||
f" {s_str} D={prob['hdim_q']}"
|
||||
f" -> {best['kernel']}{ns_tag}"
|
||||
f" ({best['tflops']:.2f} TFLOPS, {best['latency_ms']:.3f} ms)"
|
||||
)
|
||||
|
||||
# CSV output: default to <build-dir>/results.csv; merge with existing file
|
||||
# keeping the faster result (higher tflops) for duplicate kernel+problem keys.
|
||||
_CSV_FIELDS = [
|
||||
"kernel",
|
||||
"dtype",
|
||||
"pipeline",
|
||||
"mode",
|
||||
"mask",
|
||||
"bias",
|
||||
"hdim_q",
|
||||
"hdim_v",
|
||||
"tile_m0",
|
||||
"tile_n0",
|
||||
"tile_k0",
|
||||
"tile_n1",
|
||||
"tile_k1",
|
||||
"tile_k0max",
|
||||
"warp_m0",
|
||||
"warp_n0",
|
||||
"warp_k0",
|
||||
"block_per_cu",
|
||||
"num_splits",
|
||||
"batch",
|
||||
"nhead_q",
|
||||
"nhead_k",
|
||||
"seqlen_q",
|
||||
"seqlen_k",
|
||||
"latency_ms",
|
||||
"tflops",
|
||||
"max_err",
|
||||
"status",
|
||||
]
|
||||
csv_path = args.csv if args.csv else str(build_dir / "results.csv")
|
||||
if not args.no_csv and all_results:
|
||||
# Build map of new results keyed by (kernel, problem-tuple)
|
||||
def _csv_key(row):
|
||||
p = row["problem"] if "problem" in row else row
|
||||
return (
|
||||
row["kernel"],
|
||||
row.get("num_splits", 0),
|
||||
p.get("batch"),
|
||||
p.get("nhead_q"),
|
||||
p.get("nhead_k"),
|
||||
p.get("seqlen_q"),
|
||||
p.get("seqlen_k"),
|
||||
p.get("hdim_q"),
|
||||
p.get("hdim_v"),
|
||||
)
|
||||
|
||||
# Load existing CSV if present
|
||||
existing = {}
|
||||
if os.path.isfile(csv_path):
|
||||
with open(csv_path, "r", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
# Convert numeric fields back from strings
|
||||
for k in row:
|
||||
if k in ("latency_ms", "tflops", "max_err"):
|
||||
try:
|
||||
row[k] = float(row[k])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif k in (
|
||||
"hdim_q",
|
||||
"hdim_v",
|
||||
"tile_m0",
|
||||
"tile_n0",
|
||||
"tile_k0",
|
||||
"tile_n1",
|
||||
"tile_k1",
|
||||
"tile_k0max",
|
||||
"warp_m0",
|
||||
"warp_n0",
|
||||
"warp_k0",
|
||||
"block_per_cu",
|
||||
"num_splits",
|
||||
"batch",
|
||||
"nhead_q",
|
||||
"nhead_k",
|
||||
"seqlen_q",
|
||||
"seqlen_k",
|
||||
):
|
||||
try:
|
||||
row[k] = int(row[k])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
key = _csv_key(row)
|
||||
existing[key] = row
|
||||
|
||||
# Merge new results — keep whichever is faster
|
||||
for r in all_results:
|
||||
row = {**r, **r["problem"]}
|
||||
del row["problem"]
|
||||
key = _csv_key(r)
|
||||
prev = existing.get(key)
|
||||
if prev is None or float(row.get("tflops", 0)) > float(
|
||||
prev.get("tflops", 0)
|
||||
):
|
||||
existing[key] = row
|
||||
|
||||
# Write merged + sorted CSV
|
||||
merged = sorted(
|
||||
existing.values(), key=lambda x: float(x.get("tflops", 0)), reverse=True
|
||||
)
|
||||
with open(csv_path, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS, extrasaction="ignore")
|
||||
writer.writeheader()
|
||||
for row in merged:
|
||||
writer.writerow(row)
|
||||
print(f"\n CSV: {csv_path} ({len(merged)} rows, sorted by tflops)")
|
||||
|
||||
if args.json:
|
||||
report = {
|
||||
"metadata": {
|
||||
"arch": args.arch,
|
||||
"jit_time_s": jit_time,
|
||||
"bench_time_s": bench_time,
|
||||
"num_kernels": len(all_configs),
|
||||
"num_built": built,
|
||||
"num_problems": len(problems),
|
||||
},
|
||||
"results": all_results,
|
||||
}
|
||||
with open(args.json, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f" JSON: {args.json}")
|
||||
|
||||
if args.log:
|
||||
from datetime import datetime
|
||||
|
||||
with open(args.log, "w") as lf:
|
||||
lf.write(f"FMHA Benchmark Log - {datetime.now().isoformat()}\n")
|
||||
lf.write(f"{'=' * 80}\n\n")
|
||||
lf.write(f"Command: {' '.join(sys.argv)}\n")
|
||||
lf.write(f"Arch: {args.arch}\n")
|
||||
lf.write(f"Tiles mode: {args.tiles}\n")
|
||||
lf.write(f"Workers: {args.workers}\n")
|
||||
lf.write(f"Build dir: {build_dir}\n")
|
||||
lf.write(f"Total configs: {len(all_configs)}\n")
|
||||
lf.write(f"Built: {built}\n")
|
||||
lf.write(f"Failed: {failed}\n")
|
||||
lf.write(f"JIT time: {jit_time:.1f}s\n")
|
||||
lf.write(f"Bench time: {bench_time:.1f}s\n")
|
||||
lf.write(f"Problems: {[str(p) for p in problems]}\n\n")
|
||||
|
||||
# All configs attempted
|
||||
lf.write(f"{'=' * 80}\n")
|
||||
lf.write(f"ALL CONFIGS ({len(all_configs)})\n")
|
||||
lf.write(f"{'=' * 80}\n\n")
|
||||
for i, (cfg, setup) in enumerate(zip(all_configs, setups)):
|
||||
status = "OK" if setup.success else "FAILED"
|
||||
lf.write(f"[{i:4d}] {status:6s} {cfg.name}\n")
|
||||
lf.write(
|
||||
f" tile=({cfg.tile_m0},{cfg.tile_n0},{cfg.tile_k0},{cfg.tile_n1},{cfg.tile_k1},{cfg.tile_k0max})"
|
||||
f" warp=({cfg.warp_m0},{cfg.warp_n0},{cfg.warp_k0})"
|
||||
f" bpc={cfg.block_per_cu}\n"
|
||||
)
|
||||
if not setup.success and setup.error:
|
||||
lf.write(f" error: {setup.error}\n")
|
||||
lf.write("\n")
|
||||
|
||||
# Failed configs summary
|
||||
lf.write(f"\n{'=' * 80}\n")
|
||||
lf.write(f"FAILED CONFIGS ({failed})\n")
|
||||
lf.write(f"{'=' * 80}\n\n")
|
||||
for cfg, setup in zip(all_configs, setups):
|
||||
if not setup.success:
|
||||
lf.write(f" {cfg.name}\n")
|
||||
if setup.error:
|
||||
lf.write(f" {setup.error}\n")
|
||||
|
||||
# Benchmark results
|
||||
if all_results:
|
||||
lf.write(f"\n{'=' * 80}\n")
|
||||
lf.write(f"BENCHMARK RESULTS ({len(all_results)} measurements)\n")
|
||||
lf.write(f"{'=' * 80}\n\n")
|
||||
sorted_results = sorted(all_results, key=lambda x: -x["tflops"])
|
||||
for r in sorted_results:
|
||||
p = r["problem"]
|
||||
lf.write(
|
||||
f" {r['tflops']:8.2f} TFLOPS {r['latency_ms']:8.3f} ms"
|
||||
f" B={p['batch']} H={p['nhead_q']} S={p['seqlen_q']} D={p['hdim_q']}"
|
||||
f" {r['kernel']}\n"
|
||||
)
|
||||
|
||||
print(f" Log: {args.log}")
|
||||
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
689
tile_engine/ops/fmha/fmha_full_benchmark.py
Normal file
689
tile_engine/ops/fmha/fmha_full_benchmark.py
Normal file
@@ -0,0 +1,689 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Full FMHA benchmark sweep.
|
||||
|
||||
JIT-compiles FMHA kernels, then for EACH test shape finds all matching
|
||||
kernels and benchmarks them, streaming results incrementally to CSV/JSON.
|
||||
|
||||
Results are printed live per-shape with the best kernel highlighted.
|
||||
TFLOPS and latency come directly from CK's HIP event timing.
|
||||
|
||||
Usage:
|
||||
# Full sweep
|
||||
python fmha_full_benchmark.py --workers 256
|
||||
|
||||
# Quick end-to-end test
|
||||
python fmha_full_benchmark.py --category smoke --variant fwd --max-kernels 10 --workers 4
|
||||
|
||||
# Filter to h128 fp16
|
||||
python fmha_full_benchmark.py --filter "c.hdim_q == 128 and c.data_type == 'fp16'"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
import numpy as np
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DISPATCHER_ROOT = _THIS_DIR.parents[2] / "dispatcher"
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen"))
|
||||
sys.path.insert(0, str(_THIS_DIR))
|
||||
|
||||
from fmha_utils import ( # noqa: E402
|
||||
detect_gpu_arch,
|
||||
setup_multiple_fmha_dispatchers,
|
||||
)
|
||||
from fmha.instance_gen import expand_sweep, apply_filter # noqa: E402
|
||||
|
||||
YAML_PATH = _THIS_DIR / "ck_fmha_testing_matrix.yaml"
|
||||
|
||||
VARIANT_CONFIGS = {
|
||||
"fwd": "configs/receipt0_fwd.json",
|
||||
"splitkv": "configs/splitkv.json",
|
||||
"pagedkv": "configs/pagedkv.json",
|
||||
"appendkv": "configs/appendkv.json",
|
||||
"batch_prefill": "configs/batch_prefill.json",
|
||||
"bwd": "configs/bwd.json",
|
||||
}
|
||||
|
||||
# Variant -> YAML section mapping. KV-cache variants use forward_tests shapes.
|
||||
VARIANT_YAML_SECTIONS = {
|
||||
"fwd": ["forward_tests"],
|
||||
"splitkv": ["forward_tests"],
|
||||
"pagedkv": ["forward_tests"],
|
||||
"appendkv": ["forward_tests"],
|
||||
"batch_prefill": ["forward_tests"],
|
||||
"bwd": ["backward_tests"],
|
||||
}
|
||||
|
||||
DTYPE_CK = {"fp16": "fp16", "bf16": "bf16", "fp8bf16": "fp8bf16", "fp8fp32": "fp8fp32"}
|
||||
DTYPE_NP = {
|
||||
"fp16": np.float16,
|
||||
"bf16": np.float16,
|
||||
"fp32": np.float32,
|
||||
"fp8bf16": np.float16,
|
||||
"fp8fp32": np.float16,
|
||||
}
|
||||
ELEM_BYTES = {"fp16": 2, "bf16": 2, "fp32": 4, "fp8bf16": 1, "fp8fp32": 1}
|
||||
|
||||
MASK_INT = {"no": 0, "top_left": 1, "generic": 3}
|
||||
BIAS_INT = {"no": 0, "bias": 1, "alibi": 2}
|
||||
KV_LAYOUT_INT = {"vectorized": 0, "linear": 1}
|
||||
KV_LOOKUP_INT = {"vllm": 0, "sglang": 1}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestShape:
|
||||
name: str
|
||||
category: str
|
||||
variant: str
|
||||
batch: int
|
||||
seqlen_q: int
|
||||
seqlen_k: int
|
||||
nhead_q: int
|
||||
nhead_k: int
|
||||
hdim_q: int
|
||||
hdim_v: int
|
||||
dtype: str
|
||||
mask: str = "no_mask"
|
||||
bias: str = "none"
|
||||
dropout: float = 0.0
|
||||
lse: bool = False
|
||||
|
||||
|
||||
def parse_yaml(
|
||||
yaml_path: Path, category: str = "smoke", sections: Optional[List[str]] = None
|
||||
) -> List[TestShape]:
|
||||
with open(yaml_path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
shapes = []
|
||||
cats = ["smoke"]
|
||||
if category in ("full", "nightly"):
|
||||
cats.append("full")
|
||||
if category == "nightly":
|
||||
cats.append("nightly")
|
||||
|
||||
section_variant_map = [("forward_tests", "fwd"), ("backward_tests", "bwd")]
|
||||
if sections:
|
||||
section_variant_map = [(s, v) for s, v in section_variant_map if s in sections]
|
||||
|
||||
for section, variant in section_variant_map:
|
||||
if section not in data:
|
||||
continue
|
||||
for cat in cats:
|
||||
for test in data[section].get(cat, []):
|
||||
for combo in itertools.product(
|
||||
test.get("batch", [1]),
|
||||
test.get("seqlen_q", [1024]),
|
||||
test.get("seqlen_k", [1024]),
|
||||
test.get("nhead_q", [16]),
|
||||
test.get("nhead_k", [16]),
|
||||
test.get("hdim_q", [128]),
|
||||
test.get("hdim_v", [128]),
|
||||
test.get("dtype", ["fp16"]),
|
||||
test.get("mask", ["no_mask"]),
|
||||
test.get("bias", ["none"]),
|
||||
test.get("dropout", [0.0]),
|
||||
test.get("lse", [False]),
|
||||
):
|
||||
b, sq, sk, hq, hk, dq, dv, dt, m, bi, dr, ls = combo
|
||||
shapes.append(
|
||||
TestShape(
|
||||
test["name"],
|
||||
cat,
|
||||
variant,
|
||||
b,
|
||||
sq,
|
||||
sk,
|
||||
hq,
|
||||
hk,
|
||||
dq,
|
||||
dv,
|
||||
dt,
|
||||
mask=m,
|
||||
bias=bi,
|
||||
dropout=dr,
|
||||
lse=ls,
|
||||
)
|
||||
)
|
||||
return shapes
|
||||
|
||||
|
||||
def bandwidth_gb_s(shape: TestShape, latency_ms: float) -> float:
|
||||
if latency_ms <= 0:
|
||||
return 0.0
|
||||
eb = ELEM_BYTES.get(shape.dtype, 2)
|
||||
total = (
|
||||
shape.batch
|
||||
* (
|
||||
shape.nhead_q * shape.seqlen_q * shape.hdim_q
|
||||
+ shape.nhead_k * shape.seqlen_k * shape.hdim_q
|
||||
+ shape.nhead_k * shape.seqlen_k * shape.hdim_v
|
||||
+ shape.nhead_q * shape.seqlen_q * shape.hdim_v
|
||||
)
|
||||
* eb
|
||||
)
|
||||
return total / (latency_ms * 1e6)
|
||||
|
||||
|
||||
FAMILY_TO_API = {
|
||||
"fwd": "fwd",
|
||||
"fwd_splitkv": "splitkv",
|
||||
"fwd_splitkv_combine": "splitkv",
|
||||
"fwd_pagedkv": "pagedkv",
|
||||
"fwd_appendkv": "appendkv",
|
||||
"batch_prefill": "batch_prefill",
|
||||
"bwd_dot_do_o": "bwd",
|
||||
"bwd_dq_dk_dv": "bwd",
|
||||
"bwd_convert_dq": "bwd",
|
||||
}
|
||||
|
||||
|
||||
def _config_to_serializable(config, so_path: str) -> dict:
|
||||
"""Convert FmhaKernelConfig + so_path to a picklable dict for subprocess."""
|
||||
return {
|
||||
"so_path": so_path,
|
||||
"api_family": FAMILY_TO_API.get(config.family, "fwd"),
|
||||
"data_type": config.data_type,
|
||||
"kernel": config.name,
|
||||
"family": config.family,
|
||||
"mode": config.mode,
|
||||
"pipeline": config.pipeline,
|
||||
"tile_m0": config.tile_m0,
|
||||
"tile_n0": config.tile_n0,
|
||||
"tile_k0": config.tile_k0,
|
||||
"tile_n1": config.tile_n1,
|
||||
"tile_k1": config.tile_k1,
|
||||
"tile_k0max": config.tile_k0max,
|
||||
"pad_s": config.pad_s,
|
||||
"pad_sk": config.pad_sk,
|
||||
"pad_d": config.pad_d,
|
||||
"pad_dv": config.pad_dv,
|
||||
"mask": config.mask,
|
||||
"bias": config.bias,
|
||||
"lse": config.lse,
|
||||
"dropout": config.dropout,
|
||||
"logits": config.logits,
|
||||
"sink": config.sink,
|
||||
"skip": config.skip_min_seqlen_q,
|
||||
"qscale": config.qscale,
|
||||
"paged_kv": config.paged_kv,
|
||||
"rope": config.rope,
|
||||
"deterministic": config.deterministic,
|
||||
"dbias": config.dbias,
|
||||
"mask_int": MASK_INT.get(config.mask, 0),
|
||||
"bias_int": BIAS_INT.get(config.bias, 0),
|
||||
"has_lse": int(config.lse),
|
||||
"has_dropout": int(config.dropout not in (False, 0, "no", "False")),
|
||||
"has_logits": int(config.logits),
|
||||
"has_sink": int(config.sink),
|
||||
"has_skip": int(config.skip_min_seqlen_q),
|
||||
"has_dbias": int(getattr(config, "dbias", False)),
|
||||
"is_store_randval": int(getattr(config, "store_randval", False)),
|
||||
"page_size": getattr(config, "page_size", 16),
|
||||
"kv_layout": KV_LAYOUT_INT.get(
|
||||
getattr(config, "kv_memory_layout", "vectorized"), 0
|
||||
),
|
||||
"kv_lookup": KV_LOOKUP_INT.get(getattr(config, "kv_lookup_table", "sglang"), 1),
|
||||
}
|
||||
|
||||
|
||||
def _shape_to_dict(shape: TestShape) -> dict:
|
||||
return {
|
||||
"name": shape.name,
|
||||
"category": shape.category,
|
||||
"variant": shape.variant,
|
||||
"batch": shape.batch,
|
||||
"seqlen_q": shape.seqlen_q,
|
||||
"seqlen_k": shape.seqlen_k,
|
||||
"nhead_q": shape.nhead_q,
|
||||
"nhead_k": shape.nhead_k,
|
||||
"hdim_q": shape.hdim_q,
|
||||
"hdim_v": shape.hdim_v,
|
||||
"dtype": shape.dtype,
|
||||
"mask": shape.mask,
|
||||
"bias": shape.bias,
|
||||
"dropout": shape.dropout,
|
||||
"lse": shape.lse,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description="Full FMHA Benchmark Sweep")
|
||||
p.add_argument("--arch", default=detect_gpu_arch())
|
||||
p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"])
|
||||
p.add_argument("--variant", default="all")
|
||||
p.add_argument("--workers", type=int, default=8)
|
||||
p.add_argument("--build-dir", default="/tmp/fmha_full_bench")
|
||||
p.add_argument("--filter", dest="filter_expr", default="")
|
||||
p.add_argument("--filter-file", default="")
|
||||
p.add_argument("--csv", default="fmha_sweep_results.csv")
|
||||
p.add_argument("--json", default="fmha_sweep_results.json")
|
||||
p.add_argument("--compile-only", action="store_true")
|
||||
p.add_argument("--max-kernels", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--shape-timeout",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Per-shape timeout in seconds (0=none)",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
build_dir = Path(args.build_dir)
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
variants = list(VARIANT_CONFIGS.keys()) if args.variant == "all" else [args.variant]
|
||||
|
||||
# ---- Phase 1: Parse shapes ----
|
||||
print(f"\n{'=' * 80}")
|
||||
print("Phase 1: Parse test shapes")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
all_shapes: List[TestShape] = []
|
||||
for variant in variants:
|
||||
sections = VARIANT_YAML_SECTIONS.get(variant, ["forward_tests"])
|
||||
vshapes = parse_yaml(YAML_PATH, args.category, sections=sections)
|
||||
for s in vshapes:
|
||||
s.variant = variant
|
||||
all_shapes.extend(vshapes)
|
||||
|
||||
print(f" Category: {args.category}")
|
||||
print(f" Variants: {variants}")
|
||||
print(f" Total shapes: {len(all_shapes)}")
|
||||
|
||||
# ---- Phase 2: Compile ----
|
||||
print(f"\n{'=' * 80}")
|
||||
print("Phase 2: Compile kernels")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# kernel_index: (hdim_q, hdim_v, dtype, variant) -> list of (so_path, cfg_dict)
|
||||
kernel_index: Dict[tuple, List[tuple]] = {}
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor as _PPE
|
||||
|
||||
_compile_pool = _PPE(max_workers=args.workers)
|
||||
BATCH_SIZE = 200
|
||||
|
||||
for variant in variants:
|
||||
cfg_path = str(_THIS_DIR / VARIANT_CONFIGS[variant])
|
||||
if not Path(cfg_path).exists():
|
||||
continue
|
||||
configs = expand_sweep(cfg_path, args.arch)
|
||||
if args.filter_expr or args.filter_file:
|
||||
configs = apply_filter(configs, args.filter_expr, args.filter_file)
|
||||
if args.max_kernels > 0:
|
||||
configs = configs[: args.max_kernels]
|
||||
if not configs:
|
||||
continue
|
||||
|
||||
n_batches = (len(configs) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||
print(
|
||||
f"\n {variant}: {len(configs)} configs, {args.workers} workers, {n_batches} batches..."
|
||||
)
|
||||
t0 = time.perf_counter()
|
||||
setups = []
|
||||
total_ok = 0
|
||||
for bi in range(n_batches):
|
||||
batch_cfgs = configs[bi * BATCH_SIZE : (bi + 1) * BATCH_SIZE]
|
||||
batch_setups = setup_multiple_fmha_dispatchers(
|
||||
batch_cfgs,
|
||||
output_dir=build_dir,
|
||||
max_workers=args.workers,
|
||||
executor=_compile_pool,
|
||||
)
|
||||
batch_ok = sum(1 for s in batch_setups if s.success)
|
||||
batch_n = len(batch_cfgs)
|
||||
total_ok += batch_ok
|
||||
setups.extend(zip(batch_cfgs, batch_setups))
|
||||
del batch_setups, batch_cfgs
|
||||
print(
|
||||
f" Batch {bi + 1}/{n_batches}: {batch_ok}/{batch_n} "
|
||||
f"(total {total_ok}, {time.perf_counter() - t0:.0f}s)",
|
||||
flush=True,
|
||||
)
|
||||
ok = total_ok
|
||||
print(f" Built {ok}/{len(configs)} in {time.perf_counter() - t0:.0f}s")
|
||||
|
||||
for config, setup in setups:
|
||||
if not setup.success:
|
||||
continue
|
||||
so_path = getattr(setup, "library_path", "") or ""
|
||||
if not so_path:
|
||||
candidate = build_dir / f"libdispatcher_fmha_{config.name}.so"
|
||||
if candidate.exists():
|
||||
so_path = str(candidate)
|
||||
if not so_path:
|
||||
continue
|
||||
cfg_dict = _config_to_serializable(config, so_path)
|
||||
key = (config.hdim_q, config.hdim_v, config.data_type, variant, config.mode)
|
||||
kernel_index.setdefault(key, []).append((so_path, cfg_dict))
|
||||
|
||||
_compile_pool.shutdown(wait=True)
|
||||
del _compile_pool
|
||||
|
||||
total_built = sum(len(v) for v in kernel_index.values())
|
||||
print(f"\n Total compiled: {total_built}")
|
||||
print(f" Unique (hdim,dtype,variant) groups: {len(kernel_index)}")
|
||||
|
||||
if args.compile_only:
|
||||
print(f"\n Compile-only. {total_built} kernels ready.")
|
||||
return
|
||||
|
||||
# ---- Phase 3: Benchmark (serial, one subprocess per kernel) ----
|
||||
print(f"\n{'=' * 80}")
|
||||
print("Phase 3: Benchmark (one subprocess per kernel, serial GPU)")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
csv_path = Path(args.csv) if os.path.isabs(args.csv) else _THIS_DIR / args.csv
|
||||
csv_fields = [
|
||||
"problem_name",
|
||||
"batch",
|
||||
"seqlen_q",
|
||||
"seqlen_k",
|
||||
"nhead_q",
|
||||
"nhead_k",
|
||||
"hdim_q",
|
||||
"hdim_v",
|
||||
"dtype",
|
||||
"kernel",
|
||||
"family",
|
||||
"mode",
|
||||
"pipeline",
|
||||
"tile_m0",
|
||||
"tile_n0",
|
||||
"tile_k0",
|
||||
"tile_n1",
|
||||
"tile_k1",
|
||||
"tile_k0max",
|
||||
"pad_s",
|
||||
"pad_sk",
|
||||
"pad_d",
|
||||
"pad_dv",
|
||||
"mask",
|
||||
"bias",
|
||||
"lse",
|
||||
"dropout",
|
||||
"logits",
|
||||
"sink",
|
||||
"skip",
|
||||
"qscale",
|
||||
"paged_kv",
|
||||
"rope",
|
||||
"deterministic",
|
||||
"dbias",
|
||||
"latency_ms",
|
||||
"tflops",
|
||||
"bandwidth_gb_s",
|
||||
]
|
||||
|
||||
# Resume: load already-completed measurements
|
||||
completed: set = set()
|
||||
if csv_path.exists() and csv_path.stat().st_size > 0:
|
||||
with open(csv_path, newline="") as f:
|
||||
for row in csv.DictReader(f):
|
||||
completed.add(
|
||||
(
|
||||
row.get("kernel", ""),
|
||||
row.get("problem_name", ""),
|
||||
str(row.get("batch", "")),
|
||||
str(row.get("seqlen_q", "")),
|
||||
row.get("dtype", ""),
|
||||
)
|
||||
)
|
||||
csv_file = open(csv_path, "a", newline="")
|
||||
writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
|
||||
print(f" Resuming: {len(completed)} measurements already in CSV")
|
||||
else:
|
||||
csv_file = open(csv_path, "w", newline="")
|
||||
writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
# Pre-filter: match shapes to kernels by (hdim, dtype, variant, mode).
|
||||
# YAML shapes are batch-mode only. Group-mode kernels need seqstart arrays
|
||||
# which batch shapes don't provide -- they all GPU fault.
|
||||
runnable = []
|
||||
for shape in all_shapes:
|
||||
ck_dtype = DTYPE_CK.get(shape.dtype, shape.dtype)
|
||||
key = (shape.hdim_q, shape.hdim_v, ck_dtype, shape.variant, "batch")
|
||||
entries = kernel_index.get(key, [])
|
||||
if entries:
|
||||
runnable.append((shape, entries))
|
||||
|
||||
# Flatten to work items, skip already-completed
|
||||
def _resume_key(cfg, shape):
|
||||
return (
|
||||
cfg.get("kernel", ""),
|
||||
shape.name,
|
||||
str(shape.batch),
|
||||
str(shape.seqlen_q),
|
||||
shape.dtype,
|
||||
)
|
||||
|
||||
work_items = []
|
||||
skipped = 0
|
||||
for shape, kernel_entries in runnable:
|
||||
for so_path, cfg in kernel_entries:
|
||||
if _resume_key(cfg, shape) in completed:
|
||||
skipped += 1
|
||||
else:
|
||||
work_items.append((shape, so_path, cfg))
|
||||
|
||||
total_work = len(work_items) + skipped
|
||||
total_measurements = len(completed)
|
||||
total_gpu_faults = 0
|
||||
bench_t0 = time.perf_counter()
|
||||
BENCH_BATCH = 50
|
||||
|
||||
worker_path = _THIS_DIR / "run_one_kernel.py"
|
||||
worker_env = os.environ.copy()
|
||||
worker_env["FMHA_PYPATH_1"] = str(_DISPATCHER_ROOT / "python")
|
||||
worker_env["FMHA_PYPATH_2"] = str(_DISPATCHER_ROOT / "codegen")
|
||||
|
||||
CFG_KEYS = [
|
||||
"kernel",
|
||||
"family",
|
||||
"mode",
|
||||
"pipeline",
|
||||
"tile_m0",
|
||||
"tile_n0",
|
||||
"tile_k0",
|
||||
"tile_n1",
|
||||
"tile_k1",
|
||||
"tile_k0max",
|
||||
"pad_s",
|
||||
"pad_sk",
|
||||
"pad_d",
|
||||
"pad_dv",
|
||||
"mask",
|
||||
"bias",
|
||||
"lse",
|
||||
"dropout",
|
||||
"logits",
|
||||
"sink",
|
||||
"skip",
|
||||
"qscale",
|
||||
"paged_kv",
|
||||
"rope",
|
||||
"deterministic",
|
||||
"dbias",
|
||||
]
|
||||
|
||||
print(f" Runnable shapes: {len(runnable)}")
|
||||
print(f" Total kernel x shape pairs: {total_work}")
|
||||
print(f" Already completed: {skipped}")
|
||||
print(f" Pending: {len(work_items)}")
|
||||
print(f" Batch size: {BENCH_BATCH} (retry individually on fault)")
|
||||
print()
|
||||
|
||||
def _run_subprocess(payload_bytes, timeout=10):
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, str(worker_path)],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
env=worker_env,
|
||||
)
|
||||
timed_out = False
|
||||
stdout_bytes = b""
|
||||
try:
|
||||
stdout_bytes, _ = proc.communicate(input=payload_bytes, timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.communicate()
|
||||
timed_out = True
|
||||
finally:
|
||||
pid = proc.pid
|
||||
if proc.poll() is None:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
for pipe in [proc.stdin, proc.stdout, proc.stderr]:
|
||||
if pipe and not pipe.closed:
|
||||
pipe.close()
|
||||
gpucore = _THIS_DIR / f"gpucore.{pid}"
|
||||
if gpucore.exists():
|
||||
gpucore.unlink(missing_ok=True)
|
||||
rc = -1 if timed_out else proc.returncode
|
||||
return stdout_bytes, rc
|
||||
|
||||
def _record_result(r, shape, cfg, shape_dict):
|
||||
nonlocal total_measurements
|
||||
lat_ms, tflops = r["ms"], r["tflops"]
|
||||
bw = bandwidth_gb_s(shape, lat_ms)
|
||||
row = {
|
||||
"problem_name": shape.name,
|
||||
"batch": shape.batch,
|
||||
"seqlen_q": shape.seqlen_q,
|
||||
"seqlen_k": shape.seqlen_k,
|
||||
"nhead_q": shape.nhead_q,
|
||||
"nhead_k": shape.nhead_k,
|
||||
"hdim_q": shape.hdim_q,
|
||||
"hdim_v": shape.hdim_v,
|
||||
"dtype": shape.dtype,
|
||||
}
|
||||
for k in CFG_KEYS:
|
||||
row[k] = cfg.get(k, "")
|
||||
row["latency_ms"] = round(lat_ms, 4)
|
||||
row["tflops"] = round(tflops, 2)
|
||||
row["bandwidth_gb_s"] = round(bw, 2)
|
||||
writer.writerow(row)
|
||||
csv_file.flush()
|
||||
total_measurements += 1
|
||||
return tflops, lat_ms
|
||||
|
||||
# Process in batches
|
||||
n_batches = (len(work_items) + BENCH_BATCH - 1) // BENCH_BATCH
|
||||
processed = 0
|
||||
for bi in range(n_batches):
|
||||
batch = work_items[bi * BENCH_BATCH : (bi + 1) * BENCH_BATCH]
|
||||
|
||||
items = []
|
||||
for shape, so_path, cfg in batch:
|
||||
cfg["so_path"] = so_path
|
||||
items.append(
|
||||
{"so_path": so_path, "shape": _shape_to_dict(shape), "cfg": cfg}
|
||||
)
|
||||
|
||||
batch_timeout = len(batch) * 2 + 5
|
||||
payload = json.dumps({"items": items}).encode()
|
||||
stdout_bytes, rc = _run_subprocess(payload, timeout=batch_timeout)
|
||||
|
||||
if rc == 0:
|
||||
batch_ok = 0
|
||||
for line in stdout_bytes.decode().strip().split("\n"):
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
idx = r.get("idx", -1)
|
||||
if not r.get("ok") or idx < 0 or idx >= len(batch):
|
||||
continue
|
||||
shape, so_path, cfg = batch[idx]
|
||||
_record_result(r, shape, cfg, items[idx]["shape"])
|
||||
batch_ok += 1
|
||||
processed += len(batch)
|
||||
print(
|
||||
f" [batch {bi + 1}/{n_batches}] {batch_ok}/{len(batch)} ok "
|
||||
f"({processed}/{len(work_items)} done, {total_measurements} total)",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
# Collect partial results flushed before the fault
|
||||
partial_done = set()
|
||||
for line in stdout_bytes.decode().strip().split("\n"):
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
idx = r.get("idx", -1)
|
||||
if r.get("ok") and 0 <= idx < len(batch):
|
||||
shape, so_path, cfg = batch[idx]
|
||||
_record_result(r, shape, cfg, items[idx]["shape"])
|
||||
partial_done.add(idx)
|
||||
|
||||
# Retry the rest one by one
|
||||
retry = [(i, b) for i, b in enumerate(batch) if i not in partial_done]
|
||||
print(
|
||||
f" [batch {bi + 1}/{n_batches}] FAULT after {len(partial_done)}/{len(batch)} ok, "
|
||||
f"retrying {len(retry)} individually...",
|
||||
flush=True,
|
||||
)
|
||||
for idx, (shape, so_path, cfg) in retry:
|
||||
cfg["so_path"] = so_path
|
||||
p = json.dumps(
|
||||
{"so_path": so_path, "shape": items[idx]["shape"], "cfg": cfg}
|
||||
).encode()
|
||||
out, single_rc = _run_subprocess(p, timeout=10)
|
||||
if single_rc != 0:
|
||||
total_gpu_faults += 1
|
||||
continue
|
||||
try:
|
||||
r = json.loads(out.decode().strip().split("\n")[0])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if r.get("ok"):
|
||||
tflops, lat_ms = _record_result(r, shape, cfg, items[idx]["shape"])
|
||||
print(
|
||||
f" {tflops:>7.1f} TFLOPS {lat_ms:.4f}ms "
|
||||
f"{cfg.get('kernel', '?')[:45]} | {shape.name}",
|
||||
flush=True,
|
||||
)
|
||||
processed += len(batch)
|
||||
print(f" ({processed}/{len(work_items)} done)", flush=True)
|
||||
|
||||
csv_file.close()
|
||||
bench_time = time.perf_counter() - bench_t0
|
||||
|
||||
# ---- Phase 4: Summary ----
|
||||
print(f"\n{'=' * 80}")
|
||||
print("Results")
|
||||
print(f"{'=' * 80}")
|
||||
print(f" Total work items: {total_work}")
|
||||
print(f" Skipped (resumed): {skipped}")
|
||||
print(f" Measurements: {total_measurements}")
|
||||
print(f" GPU faults: {total_gpu_faults}")
|
||||
print(f" Benchmark time: {bench_time:.1f}s")
|
||||
print(f" CSV: {csv_path}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
175
tile_engine/ops/fmha/run_full_sweep.py
Normal file
175
tile_engine/ops/fmha/run_full_sweep.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Full FMHA benchmark sweep, organized by variant and dtype.
|
||||
|
||||
Compiles all kernels per variant (shared build dir for caching),
|
||||
benchmarks against all smoke shapes, then splits results into:
|
||||
|
||||
<output_dir>/
|
||||
fwd/fp16/results.csv
|
||||
fwd/bf16/results.csv
|
||||
splitkv/fp16/results.csv
|
||||
...
|
||||
bwd_dot_do_o/fp16/results.csv
|
||||
bwd_dq_dk_dv/fp16/results.csv
|
||||
bwd_convert_dq/fp16/results.csv
|
||||
|
||||
Usage:
|
||||
python run_full_sweep.py --workers 256
|
||||
python run_full_sweep.py --workers 256 --category full --output /tmp/fmha_sweep
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
|
||||
VARIANTS = ["fwd", "splitkv", "pagedkv", "appendkv", "batch_prefill", "bwd"]
|
||||
|
||||
BWD_FAMILIES = ["bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"]
|
||||
|
||||
|
||||
def run_variant(variant, category, workers, build_dir, raw_csv, shape_timeout=600):
|
||||
"""Run fmha_full_benchmark.py for one variant, return path to raw CSV."""
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(_THIS_DIR / "fmha_full_benchmark.py"),
|
||||
"--category",
|
||||
category,
|
||||
"--variant",
|
||||
variant,
|
||||
"--workers",
|
||||
str(workers),
|
||||
"--build-dir",
|
||||
str(build_dir),
|
||||
"--csv",
|
||||
str(raw_csv),
|
||||
"--json",
|
||||
str(raw_csv.with_suffix(".json")),
|
||||
"--shape-timeout",
|
||||
str(shape_timeout),
|
||||
]
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f" Variant: {variant}")
|
||||
print(f" Command: {' '.join(cmd)}")
|
||||
print(f"{'=' * 80}", flush=True)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
proc = subprocess.run(cmd, env=env)
|
||||
return proc.returncode
|
||||
|
||||
|
||||
def split_csv(raw_csv, output_dir):
|
||||
"""Split a raw CSV into per-family per-dtype subdirectories."""
|
||||
if not raw_csv.exists():
|
||||
return {}
|
||||
|
||||
counts = defaultdict(int)
|
||||
writers = {}
|
||||
files = {}
|
||||
|
||||
with open(raw_csv, newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
fieldnames = reader.fieldnames
|
||||
|
||||
for row in reader:
|
||||
family = row.get("family", "unknown")
|
||||
dtype = row.get("dtype", "unknown")
|
||||
key = (family, dtype)
|
||||
|
||||
if key not in writers:
|
||||
d = output_dir / family / dtype
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
fh = open(d / "results.csv", "w", newline="")
|
||||
w = csv.DictWriter(fh, fieldnames=fieldnames)
|
||||
w.writeheader()
|
||||
writers[key] = w
|
||||
files[key] = fh
|
||||
|
||||
writers[key].writerow(row)
|
||||
counts[key] += 1
|
||||
|
||||
for fh in files.values():
|
||||
fh.close()
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(
|
||||
description="Full FMHA Sweep (organized by variant/dtype)"
|
||||
)
|
||||
p.add_argument("--workers", type=int, default=256)
|
||||
p.add_argument("--category", default="smoke", choices=["smoke", "full", "nightly"])
|
||||
p.add_argument("--output", default="/tmp/fmha_sweep")
|
||||
p.add_argument("--build-dir", default="/tmp/fmha_sweep_build")
|
||||
p.add_argument(
|
||||
"--variants",
|
||||
nargs="+",
|
||||
default=VARIANTS,
|
||||
choices=VARIANTS,
|
||||
help="Which variants to run",
|
||||
)
|
||||
p.add_argument(
|
||||
"--shape-timeout", type=int, default=600, help="Per-shape timeout in seconds"
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
output_dir = Path(args.output)
|
||||
build_dir = Path(args.build_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
build_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
grand_total = defaultdict(int)
|
||||
|
||||
for variant in args.variants:
|
||||
raw_csv = output_dir / f"_raw_{variant}.csv"
|
||||
rc = run_variant(
|
||||
variant, args.category, args.workers, build_dir, raw_csv, args.shape_timeout
|
||||
)
|
||||
if rc != 0:
|
||||
print(f"\n WARNING: {variant} exited with code {rc}", flush=True)
|
||||
|
||||
counts = split_csv(raw_csv, output_dir)
|
||||
for key, n in counts.items():
|
||||
grand_total[key] += n
|
||||
family, dtype = key
|
||||
print(f" {family}/{dtype}: {n} measurements")
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SWEEP COMPLETE")
|
||||
print(f"{'=' * 80}")
|
||||
print(f" Total time: {elapsed / 60:.1f} min")
|
||||
print(f" Output dir: {output_dir}")
|
||||
print()
|
||||
print(f" {'Family':<25} {'Dtype':<10} {'Measurements':>12}")
|
||||
print(f" {'-' * 25} {'-' * 10} {'-' * 12}")
|
||||
total = 0
|
||||
for (family, dtype), n in sorted(grand_total.items()):
|
||||
print(f" {family:<25} {dtype:<10} {n:>12,}")
|
||||
total += n
|
||||
print(f" {'-' * 25} {'-' * 10} {'-' * 12}")
|
||||
print(f" {'TOTAL':<25} {'':<10} {total:>12,}")
|
||||
|
||||
print("\n Directory structure:")
|
||||
for d in sorted(output_dir.rglob("results.csv")):
|
||||
rel = d.relative_to(output_dir)
|
||||
print(f" {rel}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
128
tile_engine/ops/fmha/run_one_kernel.py
Normal file
128
tile_engine/ops/fmha/run_one_kernel.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Run FMHA kernel(s) on GPU and report timing.
|
||||
|
||||
Single mode: stdin = {"so_path": ..., "shape": {...}, "cfg": {...}}
|
||||
Batch mode: stdin = {"items": [{"so_path": ..., "shape": {...}, "cfg": {...}}, ...]}
|
||||
|
||||
Each result prints one JSON line to stdout (flushed immediately):
|
||||
{"idx": 0, "ok": true, "ms": 0.123, "tflops": 456.7}
|
||||
{"idx": 1, "ok": false}
|
||||
|
||||
Flushing per-line lets the parent recover partial results if a later
|
||||
kernel causes a GPU fault that kills this process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
for p in [os.environ.get("FMHA_PYPATH_1", ""), os.environ.get("FMHA_PYPATH_2", "")]:
|
||||
if p and p not in sys.path:
|
||||
sys.path.insert(0, p)
|
||||
|
||||
from fmha_utils import FmhaProblem, FmhaRunner # noqa: E402
|
||||
|
||||
DTYPE_NP = {
|
||||
"fp16": np.float16,
|
||||
"bf16": np.float16,
|
||||
"fp32": np.float32,
|
||||
"fp8bf16": np.float16,
|
||||
"fp8fp32": np.float16,
|
||||
}
|
||||
|
||||
|
||||
def _run_one(idx, so_path, s, cfg):
|
||||
prob = FmhaProblem(
|
||||
batch=s["batch"],
|
||||
nhead_q=s["nhead_q"],
|
||||
nhead_k=s["nhead_k"],
|
||||
seqlen_q=s["seqlen_q"],
|
||||
seqlen_k=s["seqlen_k"],
|
||||
hdim_q=s["hdim_q"],
|
||||
hdim_v=s["hdim_v"],
|
||||
)
|
||||
dt = DTYPE_NP.get(s.get("dtype", "fp16"), np.float16)
|
||||
np.random.seed(42)
|
||||
q = (np.random.randn(*prob.q_shape()) * 0.1).astype(dt)
|
||||
k = (np.random.randn(*prob.k_shape()) * 0.1).astype(dt)
|
||||
v = (np.random.randn(*prob.v_shape()) * 0.1).astype(dt)
|
||||
|
||||
runner = FmhaRunner.from_library(so_path)
|
||||
api = cfg.get("api_family", "fwd")
|
||||
|
||||
if api == "bwd":
|
||||
out_buf = (
|
||||
np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"], s["hdim_v"]) * 0.1
|
||||
).astype(dt)
|
||||
lse = np.random.randn(s["batch"], s["nhead_q"], s["seqlen_q"]).astype(
|
||||
np.float32
|
||||
)
|
||||
d_out = (np.random.randn(*out_buf.shape) * 0.1).astype(dt)
|
||||
result = runner.run_bwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out_buf,
|
||||
lse,
|
||||
d_out,
|
||||
prob,
|
||||
data_type=cfg.get("data_type", "fp16"),
|
||||
mask_type=cfg.get("mask_int", 0),
|
||||
bias_type=cfg.get("bias_int", 0),
|
||||
has_dropout=cfg.get("has_dropout", 0),
|
||||
has_dbias=cfg.get("has_dbias", 0),
|
||||
is_deterministic=cfg.get("deterministic", 0),
|
||||
is_group_mode=cfg.get("mode", "batch") == "group",
|
||||
is_store_randval=cfg.get("is_store_randval", 0),
|
||||
tile_n0=cfg.get("tile_n0", 128),
|
||||
)
|
||||
else:
|
||||
result = runner.run(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
prob,
|
||||
mask_type=cfg.get("mask_int", 0),
|
||||
bias_type=cfg.get("bias_int", 0),
|
||||
has_lse=cfg.get("has_lse", 0),
|
||||
has_dropout=cfg.get("has_dropout", 0),
|
||||
has_logits=cfg.get("has_logits", 0),
|
||||
has_sink=cfg.get("has_sink", 0),
|
||||
has_skip=cfg.get("has_skip", 0),
|
||||
api_family=api,
|
||||
data_type=cfg.get("data_type", "fp16"),
|
||||
page_size=cfg.get("page_size", 16),
|
||||
kv_layout=cfg.get("kv_layout", 0),
|
||||
kv_lookup=cfg.get("kv_lookup", 1),
|
||||
is_group_mode=cfg.get("mode", "batch") == "group",
|
||||
)
|
||||
|
||||
if result.success:
|
||||
print(
|
||||
json.dumps(
|
||||
{"idx": idx, "ok": True, "ms": result.time_ms, "tflops": result.tflops}
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(json.dumps({"idx": idx, "ok": False}), flush=True)
|
||||
|
||||
|
||||
def main():
|
||||
d = json.loads(sys.stdin.buffer.read())
|
||||
|
||||
if "items" in d:
|
||||
for i, item in enumerate(d["items"]):
|
||||
_run_one(i, item["so_path"], item["shape"], item["cfg"])
|
||||
else:
|
||||
_run_one(0, d["cfg"]["so_path"], d["shape"], d["cfg"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user