Chao 1d1be9e3de [rocm-libraries] ROCm/rocm-libraries#6529 (commit 93a6097)
[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950 (#6529)


## Motivation

Enable the existing V3 persistent kernel path for CK-Tile FMHA forward on gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already exist but are disabled via a hardcoded `F_is_v3_enabled=False`.

This change replaces the compile-time gate with a runtime environment variable, `CK_FMHA_ENABLE_V3=1` (disabled by default, opt-in). When enabled:
- **Prefill** workloads (seqlen_q > 1) dispatch to the V3 persistent pipeline
- **Decode** workloads (seqlen_q == 1) always use V2, which is better suited to their memory-bound profile

The V3 persistent kernel uses grid-stride scheduling, XCD-interleaved tile assignment for L2 locality, LPT reversal for causal masks, and gfx950 async buffer loads.

## Technical Details

Single file: `example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py`
- Add `#include <cstdlib>` and `<string>` for `std::getenv`
- Replace the `{F_is_v3_enabled}` template parameter with a runtime env-var check
- Add `seqlen_q > 1` guard (decode always uses V2)
- Remove `.format()` call in `write_fwd_api()`

## Dependencies

Depends on https://github.com/ROCm/rocm-libraries/pull/6501 — builds on
XCD-interleave and LPT scheduling infrastructure.

## Test Plan

- GPU validation on MI300X (gfx942, ROCm 6.4.1):
  - Command: `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
- GPU validation on MI350X (gfx950, ROCm 7.0):
  - Command (V2): `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
  - Command (V3): `CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
  - Command (decode, always V2): `./build/bin/tile_example_fmha_fwd -b=64 -h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3`

## Test Result

Benchmark results (MI350X, gfx950, ROCm 7.0):

| Config | V2 (TFlops) | V3 (TFlops) | Speedup |
|--------|-------------|-------------|---------|
| Non-causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 696.3 | 884.2 | **+27.0%** |
| Causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 371.3 | 494.9 | **+33.3%** |
| GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 671.3 | 831.7 | **+23.9%** |
| LLaMA-70B b=1 h=64 hk=8 s=4096 d=128 bf16 | 761.5 | 927.3 | **+21.8%** |
| Causal GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 345.4 | 631.9 | **+82.9%** |
| Long-seq b=1 h=16 s=16384 d=128 bf16 | 797.8 | 969.9 | **+21.6%** |
| Decode b=64 h=32 hk=8 s=1 s_k=4096 bf16 | 1828 GB/s | — (V2 path) | unaffected |

Benchmark results (MI300X, gfx942, ROCm 6.4.1):

V3 has no effect on MI300X: it relies on gfx950 async buffer loads and falls back to the V2 code path on gfx942. No regression on any config.

| Config | TFlops / GB/s | Time (ms) | Delta vs baseline |
|--------|---------------|-----------|-------------------|
| MHA bf16 b=2 h=8 s=4096 d=128 | 342.98 TFlops | 0.401 | +0.1% |
| MHA fp16 b=2 h=8 s=4096 d=128 | 411.18 TFlops | 0.334 | +4.9% |
| Causal MHA bf16 b=2 h=8 s=4096 d=128 | 232.61 TFlops | 0.296 | +2.4% |
| GQA 4:1 bf16 b=2 h=32 hk=8 s=2048 d=128 | 320.07 TFlops | 0.429 | -1.4% |
| GQA 8:1 bf16 b=2 h=64 hk=8 s=2048 d=128 | 353.91 TFlops | 0.777 | +1.7% |
| LLaMA-70B prefill b=1 h=64 hk=8 s=4096 d=128 bf16 | 381.53 TFlops | 1.441 | +1.2% |
| Long-seq bf16 b=1 h=16 s=16384 d=128 | 388.61 TFlops | 5.659 | +1.4% |
| Decode b=64 h=32 hk=8 s_k=4096 d=128 bf16 | 693.40 GB/s | 1.550 | +0.3% |

All validation tests pass (`valid:y`) on both MI300X and MI350X.

Additional validation:
- `CK_FMHA_ENABLE_V3=0` correctly falls back to V2 (default behavior unchanged)
- `CK_FMHA_ENABLE_V3=1` dispatches to V3 for prefill, V2 for decode
- Validation passes across fp16/bf16, batch/group mode, causal/non-causal
- No regression on the decode path

CK Tile Example Suite

This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.


Example Index

| Example | Operation | Description |
|---------|-----------|-------------|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |

Technical Highlights


How to Build & Run

mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j

Each example produces its own executable in build/bin/.


Learning and Extending


References


Back to Composable Kernel Examples