composable_kernel/example/ck_tile
Yi DING 6a9c03f692 [rocm-libraries] ROCm/rocm-libraries#7450 (commit 402dbad)
[CK_TILE] Use Persistent Scheduling for FMHA BWD Group Deterministic (#7450)

## Motivation

FMHA BWD group-mode deterministic currently uses a non-persistent
scheduler: each `(batch, head, K-row)` work item is launched as its own
block, with no work-stealing across CUs. On uneven workloads (varlen,
GQA, many heads with few K-rows) this leaves CUs idle and forces a
larger dq_acc workspace than necessary.

This PR ports the persistent + deterministic scheduling already used in
batch mode to group mode: a fixed-grid kernel whose per-CU work ranges
are pre-computed on the host, combined with sparse dq_acc slot indexing
so that multiple K-rows handled by the same CU share one accumulator
slot via intra-CU atomic adds.

Stacked on #7331; merge that first.

## Technical Details

Single file changed: `ops/fmha/kernel/fmha_bwd_kernel.hpp`.

A new `kUsePersistent` path is added to the group-mode deterministic
kernel, mirroring the batch-mode persistent scheduler. The host
pre-computes a fixed per-CU partition of the total
`(batch, head, K-row)` work and packs it into `cu_states[]` so the GPU
consumes it in a single launch. Host preparation happens in four steps:

1. Build per-batch `seqstart` prefix sums.
2. Fill per-batch `(sq_w, nc)` with a placeholder `nsplits` (bumped in
step 3).
3. Two-pointer scan over CUs to fill `cu_states[c]` (`isplit`,
`head_start`, `c_start`, `w_lo`, `w_hi`), accumulating `nsplits[b]` as
`max(cs->isplit + 1)`.
4. Compute compact per-batch dq_acc offsets from the finalized
`nsplits`.
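Steps 1 and 3 can be illustrated with a simplified model. The sketch below is hypothetical: it assumes every K-row of batch `b` costs `sq_w[b]` work units, that work is flattened in `(batch, head, K-row)` order, and that each CU receives a contiguous range of `target_w = ceil(total / num_cu)` units. `CuRange` and `partition_work` are illustrative names, not the kernel's actual types, and the slot index and offsets are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical per-CU record; field names echo the PR text, semantics are assumed.
struct CuRange {
    int batch;          // batch in which this CU's range starts
    int head_start;     // head (within that batch) where the range starts
    int64_t wc_start;   // offset of w_lo inside that (batch, head) slice
    int64_t w_lo, w_hi; // global half-open work range [w_lo, w_hi)
};

// Assumed cost model: each K-row of batch b costs sq_w[b] units; batch b has
// nc[b] K-rows per head. Work is flattened in (batch, head, K-row) order.
std::vector<CuRange> partition_work(const std::vector<int64_t>& sq_w,
                                    const std::vector<int64_t>& nc,
                                    int nhead, int num_cu) {
    assert(num_cu > 0 && !sq_w.empty() && sq_w.size() == nc.size());
    const size_t nbatch = sq_w.size();

    // Step 1: prefix sums; seqstart[b] is the first work unit of batch b.
    std::vector<int64_t> seqstart(nbatch + 1, 0);
    for (size_t b = 0; b < nbatch; ++b)
        seqstart[b + 1] = seqstart[b] + int64_t(nhead) * nc[b] * sq_w[b];
    const int64_t total = seqstart[nbatch];
    const int64_t target_w = (total + num_cu - 1) / num_cu; // ceil division

    // Step 3 (range part): CUs and batches are walked together; the batch
    // pointer only moves forward, so the scan is O(num_cu + nbatch).
    std::vector<CuRange> out;
    size_t b = 0;
    for (int c = 0; c < num_cu; ++c) {
        const int64_t w_lo = std::min<int64_t>(int64_t(c) * target_w, total);
        const int64_t w_hi = std::min<int64_t>(w_lo + target_w, total);
        // Skip batches that end at or before w_lo, including zero-work
        // batches (e.g. seqlen_q == 0 gives sq_w == 0).
        while (b + 1 < nbatch && seqstart[b + 1] <= w_lo) ++b;
        const int64_t head_work = nc[b] * sq_w[b];
        const int64_t local = w_lo - seqstart[b];
        // For empty ranges (w_lo == w_hi) these indices carry no meaning.
        const int head = head_work > 0 ? int(local / head_work) : 0;
        const int64_t wc = head_work > 0 ? local % head_work : 0;
        out.push_back({int(b), head, wc, w_lo, w_hi});
    }
    return out;
}
```

In this model a CU's range may cross head and batch boundaries; a kernel consuming it would have to continue into the next `(batch, head)` slice when it reaches the end of the current one.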

`isplit` is the sparse dq_acc slot index: all K-rows written by one CU
share slot `ceil(wc_start / denom)`, which enables intra-CU atomic
accumulation instead of allocating one slot per K-row.

`denom = max(sq_w, target_w)`, which separates two regimes:

- `target_w >= sq_w` (large work): `denom = target_w`, intra-CU atomic
optimization engaged.
- `target_w < sq_w` (sub-K-row sharding, multiple CUs sharing one
K-row): `denom = sq_w` collapses to per-K-row indexing (`= c_start`),
keeping `isplit ∈ [0, nc-1]` and matching the `nsplits_max =
ceil(s_k/kN0) = nc` upper bound that #7331's
`GetWorkspaceDeviceSizeUpperBound` assumes for group+det.

`isplit` is additionally clamped to `nc-1` to absorb empty CUs (whose
rounded-up `wc_start` lies past the last K-row); such CUs never write
dq_acc on the GPU, so their slot value is harmless.
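The slot formula, the two regimes, and the clamp can be combined into one small helper. This is a sketch, assuming `wc_start` is the CU's start offset inside its `(batch, head)` slice, measured in work units and already rounded up to a K-row boundary (as the clamping paragraph implies); the function name `sparse_slot` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper: isplit = min(ceil(wc_start / denom), nc - 1),
// with denom = max(sq_w, target_w) selecting the regime.
int sparse_slot(int64_t wc_start, int64_t sq_w, int64_t target_w, int64_t nc) {
    assert(wc_start >= 0 && sq_w > 0 && target_w > 0 && nc > 0);
    const int64_t denom = std::max(sq_w, target_w);          // regime switch
    const int64_t isplit = (wc_start + denom - 1) / denom;   // ceil division
    return int(std::min(isplit, nc - 1));                    // clamp for empty CUs
}
```

When `target_w < sq_w`, `denom` becomes `sq_w` and the slot equals the (rounded-up) K-row index; for a non-empty CU that index is at most `nc - 1`, so only empty CUs ever hit the clamp.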

`nsplits[b]` is accumulated dynamically in step 3 rather than via a
closed form so it tightly matches the actual sparse slots used; step 4
(offsets) follows step 3 since offsets now depend on the dynamic
`nsplits`.
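The dynamic `nsplits` accumulation at the end of step 3 and the compact offsets of step 4 can be sketched as follows. The `SlotUse` record and the per-batch split size `per_split_elems` are illustrative assumptions, not the kernel's actual data layout; `nsplits` starts at 0 here, playing the role of the step 2 placeholder.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical record: the batch a CU writes to and the slot it was given.
struct SlotUse { int batch; int isplit; };

// Returns offset[0..nbatch]; batch b's dq_acc region is [offset[b], offset[b+1]).
std::vector<int64_t> dq_acc_offsets(const std::vector<SlotUse>& uses,
                                    const std::vector<int64_t>& per_split_elems,
                                    std::vector<int>& nsplits) {
    const size_t nbatch = per_split_elems.size();
    // Step 3 (tail): nsplits[b] = max(isplit + 1) over CUs that touch batch b.
    nsplits.assign(nbatch, 0);
    for (const SlotUse& u : uses) {
        assert(u.batch >= 0 && size_t(u.batch) < nbatch);
        nsplits[u.batch] = std::max(nsplits[u.batch], u.isplit + 1);
    }
    // Step 4: exclusive prefix sum of each batch's compact workspace size.
    std::vector<int64_t> offset(nbatch + 1, 0);
    for (size_t b = 0; b < nbatch; ++b)
        offset[b + 1] = offset[b] + int64_t(nsplits[b]) * per_split_elems[b];
    return offset;
}
```

Because the offsets are a prefix sum over the final `nsplits`, they can only be computed after every CU's slot is known, which matches the step 3 then step 4 ordering.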

Group mode also allows batches with `seqlen_q == 0`. The persistent
scheduler skips them on the dQ path (there is no work), but their dK/dV
are still zero-filled.

## Test Plan

Built `tile_example_fmha_bwd` with receipt 5 (fp16, no-bias, no-dropout,
`dpad == dvpad`, group + batch) on gfx950 (MI355X).

- 8-case smoke (shapes that exercise the sub-K-row regime).
- 44-case sweep covering: mask 0/1/2, GQA, var seqlen, `d != d_v`,
  extreme small seqlen / `nc=1`, CU >> work, huge batch, batch-mode
  regression.
- 12-case perf comparison vs the non-persistent baseline (warmup=10,
  repeat=50).

## Test Result

- All 8 + 44 cases `valid:y`.
- Perf: within ±5% noise, with an average of -0.4% across the 12 cases (neutral).
- Batch-mode deterministic / non-deterministic regression unchanged.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-26 10:01:54 +08:00

CK Tile Example Suite

This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, enabling both high efficiency and flexibility.
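As a CPU-side analogy only (plain C++, not the CK Tile API), the tile idea can be shown with a tiled matrix multiply: each output tile is accumulated in a small local array standing in for registers, while input tiles are staged through small buffers standing in for LDS. The tile sizes `TM`, `TN`, `TK` and the function `tiled_gemm` are illustrative.

```cpp
#include <cassert>
#include <vector>

// C[M x N] += A[M x K] * B[K x N], all row-major, computed tile by tile.
template <int TM, int TN, int TK>
void tiled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, int M, int N, int K) {
    assert(int(A.size()) == M * K && int(B.size()) == K * N && int(C.size()) == M * N);
    for (int m0 = 0; m0 < M; m0 += TM)
        for (int n0 = 0; n0 < N; n0 += TN) {
            float acc[TM][TN] = {};                 // "register" tile
            for (int k0 = 0; k0 < K; k0 += TK) {
                float a[TM][TK] = {}, b[TK][TN] = {}; // "LDS" tiles, zero-padded
                for (int i = 0; i < TM && m0 + i < M; ++i)
                    for (int k = 0; k < TK && k0 + k < K; ++k)
                        a[i][k] = A[(m0 + i) * K + k0 + k];
                for (int k = 0; k < TK && k0 + k < K; ++k)
                    for (int j = 0; j < TN && n0 + j < N; ++j)
                        b[k][j] = B[(k0 + k) * N + n0 + j];
                for (int i = 0; i < TM; ++i)          // compute on staged tiles
                    for (int j = 0; j < TN; ++j)
                        for (int k = 0; k < TK; ++k)
                            acc[i][j] += a[i][k] * b[k][j];
            }
            for (int i = 0; i < TM && m0 + i < M; ++i) // write the tile back
                for (int j = 0; j < TN && n0 + j < N; ++j)
                    C[(m0 + i) * N + n0 + j] += acc[i][j];
        }
}
```

Zero-padding the staged tiles lets the inner compute loop run over full tile sizes even when `M`, `N`, or `K` is not a multiple of the tile size; only the loads and the write-back check bounds.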


Example Index

| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, and mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |

Technical Highlights


How to Build & Run

```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```

Each example produces its own executable in `build/bin/`.


Learning and Extending


References


Back to Composable Kernel Examples