[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950 (#6529)

## Motivation

Enable the existing V3 persistent kernel path for CK-Tile FMHA forward on gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already exist but are disabled by a hardcoded `F_is_v3_enabled=False`.

This change replaces the compile-time gate with a runtime environment variable, `CK_FMHA_ENABLE_V3=1` (disabled by default, opt-in). When it is enabled:

- **Prefill** workloads (`seqlen_q > 1`) dispatch to the V3 persistent pipeline.
- **Decode** workloads (`seqlen_q == 1`) always use V2, which is better suited to this memory-bound case.

The V3 persistent kernel uses grid-stride scheduling, XCD-interleaved tile assignment for L2 locality, LPT (longest-processing-time-first) reversal for causal masks, and gfx950 async buffer loads.

## Technical Details

Single file changed: `example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py`

- Add `#include <cstdlib>` (for `std::getenv`) and `#include <string>`.
- Replace the `{F_is_v3_enabled}` template parameter with a runtime environment-variable check.
- Add a `seqlen_q > 1` guard so decode always uses V2.
- Remove the `.format()` call in `write_fwd_api()`.

## Dependencies

Depends on https://github.com/ROCm/rocm-libraries/pull/6501. This change builds on that PR's XCD-interleave and LPT scheduling infrastructure.
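The Motivation and Technical Details describe a runtime gate and several scheduling terms. The host-side C++ sketch below makes them concrete under stated assumptions.

`fmha_fwd_use_v3`, `xcd_interleave`, and `persistent_q_tiles` are hypothetical names, not CK-Tile APIs. The XCD remap formula also rests on an assumption: it supposes the hardware dispatches workgroups round-robin across XCDs. The real kernel schedules over batch and head dimensions as well as q-tiles, so this only illustrates the ideas.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical helper modelling the runtime gate: V3 is opt-in via
// CK_FMHA_ENABLE_V3=1 and only applies to prefill (seqlen_q > 1).
bool fmha_fwd_use_v3(int seqlen_q)
{
    const char* env = std::getenv("CK_FMHA_ENABLE_V3");
    const bool enabled = (env != nullptr) && (std::string(env) == "1");
    return enabled && seqlen_q > 1; // decode (seqlen_q == 1) always stays on V2
}

// Hypothetical helper: XCD-interleaved remap of a workgroup id.
// ASSUMPTION: workgroups are dispatched round-robin across num_xcds dies, so
// after this remap the workgroups that share one die (and its L2) get
// consecutive logical ids. One common formulation, not necessarily CK's.
int xcd_interleave(int wg, int num_wgs, int num_xcds)
{
    assert(num_xcds > 0 && num_wgs % num_xcds == 0);
    const int wgs_per_xcd = num_wgs / num_xcds;
    return (wg % num_xcds) * wgs_per_xcd + wg / num_xcds;
}

// Hypothetical helper: q-tiles visited by one persistent workgroup (given its
// logical id) using a grid-stride loop. For causal masks the tile index is
// reversed so the most expensive tiles run first (LPT); with a causal mask and
// seqlen_q == seqlen_k, later q-tiles attend to more keys.
std::vector<int> persistent_q_tiles(int logical_wg, int num_wgs, int num_q_tiles, bool causal)
{
    std::vector<int> tiles;
    for(int t = logical_wg; t < num_q_tiles; t += num_wgs) // grid-stride loop
        tiles.push_back(causal ? num_q_tiles - 1 - t : t);
    return tiles;
}
```

In this sketch, only `CK_FMHA_ENABLE_V3=1` combined with a prefill-shaped call selects V3. Any other value, or leaving the variable unset, keeps the default V2 path.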
## Test Plan

- GPU validation on MI300X (gfx942, ROCm 6.4.1):
  - Command: `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
- GPU validation on MI350X (gfx950, ROCm 7.0):
  - Command (V2): `./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
  - Command (V3): `CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3`
  - Command (decode, always V2): `./build/bin/tile_example_fmha_fwd -b=64 -h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3`

## Test Result

Benchmark results (MI350X, gfx950, ROCm 7.0):

| Config | V2 (TFlops) | V3 (TFlops) | Speedup |
|--------|-------------|-------------|---------|
| Non-causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 696.3 | 884.2 | **+27.0%** |
| Causal b=2 h=8 hk=2 s=4096 d=128 bf16 | 371.3 | 494.9 | **+33.3%** |
| GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 671.3 | 831.7 | **+23.9%** |
| LLaMA-70B b=1 h=64 hk=8 s=4096 d=128 bf16 | 761.5 | 927.3 | **+21.8%** |
| Causal GQA b=2 h=32 hk=8 s=2048 d=128 bf16 | 345.4 | 631.9 | **+82.9%** |
| Long-seq b=1 h=16 s=16384 d=128 bf16 | 797.8 | 969.9 | **+21.6%** |
| Decode b=64 h=32 hk=8 s=1 s_k=4096 bf16 | 1828 GB/s | n/a (V2 path) | unaffected |

Benchmark results (MI300X, gfx942, ROCm 6.4.1):

V3 has no effect on MI300X. V3 relies on gfx950 async buffer loads, so gfx942 falls back to the V2 code path. No regression was observed on any config.
| Config | Throughput (TFlops or GB/s) | Time (ms) | Delta vs baseline |
|--------|-----------------------------|-----------|-------------------|
| MHA bf16 b=2 h=8 s=4096 d=128 | 342.98 TFlops | 0.401 | +0.1% |
| MHA fp16 b=2 h=8 s=4096 d=128 | 411.18 TFlops | 0.334 | +4.9% |
| Causal MHA bf16 b=2 h=8 s=4096 d=128 | 232.61 TFlops | 0.296 | +2.4% |
| GQA 4:1 bf16 b=2 h=32 hk=8 s=2048 d=128 | 320.07 TFlops | 0.429 | -1.4% |
| GQA 8:1 bf16 b=2 h=64 hk=8 s=2048 d=128 | 353.91 TFlops | 0.777 | +1.7% |
| LLaMA-70B prefill b=1 h=64 hk=8 s=4096 d=128 bf16 | 381.53 TFlops | 1.441 | +1.2% |
| Long-seq bf16 b=1 h=16 s=16384 d=128 | 388.61 TFlops | 5.659 | +1.4% |
| Decode b=64 h=32 hk=8 s_k=4096 d=128 bf16 | 693.40 GB/s | 1.550 | +0.3% |

All validation tests pass (`valid:y`) on both MI300X and MI350X.

Additional validation:

- `CK_FMHA_ENABLE_V3=0` correctly falls back to V2 (default behavior unchanged).
- `CK_FMHA_ENABLE_V3=1` dispatches to V3 for prefill and to V2 for decode.
- Validation passes across fp16/bf16, batch/group mode, and causal/non-causal.
- No regression on the decode path.
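The throughput and speedup columns can be cross-checked from the listed shapes and times. The C++ sketch below makes three assumptions:

- It uses the common FMHA forward FLOP count: two GEMMs of 2·s_q·s_k·d FLOPs each per batch and head, approximately halved for causal masks.
- For decode, it assumes traffic is dominated by reading K and V.
- The speedup is taken as V3/V2 - 1.

These are not necessarily the exact formulas `tile_example_fmha_fwd` uses, and the function names are hypothetical.

```cpp
#include <cassert>

// ASSUMED FLOP model for FMHA forward: Q*K^T and P*V each cost 2*s_q*s_k*d
// FLOPs per (batch, head); causal masking is approximated as half the work.
double fmha_fwd_tflops(int b, int h, int s_q, int s_k, int d, bool causal, double time_ms)
{
    assert(time_ms > 0.0);
    double flops = 4.0 * b * h * double(s_q) * s_k * d;
    if(causal)
        flops *= 0.5;
    return flops / (time_ms * 1e-3) / 1e12;
}

// ASSUMED byte model for decode: reading K and V (h_k KV heads each) dominates,
// so the small Q and O tensors are ignored.
double fmha_decode_gbps(int b, int h_k, int s_k, int d, int bytes_per_elem, double time_ms)
{
    assert(time_ms > 0.0);
    const double bytes = 2.0 * b * h_k * double(s_k) * d * bytes_per_elem; // K + V
    return bytes / (time_ms * 1e-3) / 1e9;
}

// Relative speedup in percent, matching the Speedup column: (V3 / V2 - 1) * 100.
double speedup_percent(double v2, double v3)
{
    assert(v2 > 0.0);
    return (v3 / v2 - 1.0) * 100.0;
}
```

With the MI300X rows, `fmha_fwd_tflops(2, 8, 4096, 4096, 128, false, 0.401)` gives about 343 TFlops. `fmha_decode_gbps(64, 8, 4096, 128, 2, 1.550)` gives about 693 GB/s. Both agree with the table up to rounding of the reported times.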
# CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.
## What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as compositions of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, enabling both high efficiency and flexibility.
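As a rough CPU-side illustration of the tile idea and the global <-> LDS <-> registers staging, here is a tiled GEMM in plain C++. It does not use the CK Tile API. `tiled_gemm` and its tile-size template parameters are hypothetical, and the local arrays merely stand in for LDS and registers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU model of tiling (NOT the CK Tile API): C = A * B with row-major
// M x K, K x N and M x N matrices. Each TM x TN output tile stages TM x TK and
// TK x TN input sub-tiles into local buffers (standing in for LDS) and
// accumulates in a per-tile array (standing in for registers).
template <int TM, int TN, int TK>
void tiled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, int M, int N, int K)
{
    assert(A.size() == std::size_t(M) * K && B.size() == std::size_t(K) * N);
    C.assign(std::size_t(M) * N, 0.0f);
    for(int m0 = 0; m0 < M; m0 += TM)
        for(int n0 = 0; n0 < N; n0 += TN)
        {
            float acc[TM][TN] = {}; // "registers" for one output tile
            for(int k0 = 0; k0 < K; k0 += TK)
            {
                float a_tile[TM][TK] = {}; // "LDS" copies, zero-padded at edges
                float b_tile[TK][TN] = {};
                for(int i = 0; i < TM && m0 + i < M; ++i)
                    for(int k = 0; k < TK && k0 + k < K; ++k)
                        a_tile[i][k] = A[std::size_t(m0 + i) * K + k0 + k];
                for(int k = 0; k < TK && k0 + k < K; ++k)
                    for(int j = 0; j < TN && n0 + j < N; ++j)
                        b_tile[k][j] = B[std::size_t(k0 + k) * N + n0 + j];
                for(int i = 0; i < TM; ++i) // compute on the staged tiles
                    for(int j = 0; j < TN; ++j)
                        for(int k = 0; k < TK; ++k)
                            acc[i][j] += a_tile[i][k] * b_tile[k][j];
            }
            for(int i = 0; i < TM && m0 + i < M; ++i) // write back the valid part
                for(int j = 0; j < TN && n0 + j < N; ++j)
                    C[std::size_t(m0 + i) * N + n0 + j] = acc[i][j];
        }
}
```

Changing `TM`, `TN`, and `TK` changes how the work is partitioned but not the result. On a GPU, that choice is what trades off data reuse against occupancy.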
## Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
## Technical Highlights

- Tile Distribution: See `include/ck_tile/tile_program/tile_distribution/` for mapping tiles to thread blocks.
- Block Tile Pipelines: See `include/ck_tile/tile_program/block_tile_pipeline/` for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
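To illustrate the block level of "mapping tiles to thread blocks", here is a minimal C++ sketch of a row-major block-to-tile map. `ceil_div` and `block_to_tile_origin` are hypothetical helpers, not functions from the directories listed above. Real examples may use other orderings, for example swizzled ones.

```cpp
#include <cassert>
#include <utility>

// Hypothetical helper: integer ceiling division.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical block-to-tile map: the linear thread-block id selects the
// (row, col) origin of one TM x TN output tile in an M x N problem, walking
// tiles in row-major order. Edge tiles may extend past M or N and are partial.
std::pair<int, int> block_to_tile_origin(int block_id, int M, int N, int TM, int TN)
{
    const int tiles_n = ceil_div(N, TN);
    assert(block_id >= 0 && block_id < ceil_div(M, TM) * tiles_n);
    return {(block_id / tiles_n) * TM, (block_id % tiles_n) * TN};
}
```

Launching `ceil_div(M, TM) * ceil_div(N, TN)` blocks gives every output tile exactly one owner. CK Tile's tile distribution also covers how a tile is split across the threads of a block, which this sketch does not model.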
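## How to Build & Run

```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```

Replace `<arch>` with the target GPU architecture (for example, `gfx942` or `gfx950`). Each example produces its own executable in `build/bin/`.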
## Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.