composable_kernel/example/ck_tile
Anton Gorenko 7ecbf82708 [rocm-libraries] ROCm/rocm-libraries#7500 (commit f5cd4fd)
[CK_TILE][FMHA] Optimize long-context decoding on gfx11/12 (#7500)

## Motivation

Relevant issue: ROCM-22065

FMHA performs suboptimally in long-context decoding (i.e. when
seqlen_q = 1) on gfx11/12.
This PR optimizes the splitkv pipeline and its configs for this scenario.

## Technical Details

Optimizations applied in this PR:
1. Use tiles with a smaller M0 (16 instead of 64); these tiles are used
when seqlen_q <= 16.
2. Adapt the qr_nwarp_sshuffle pipeline for gfx11. It allows more warps
to be used even with M0 = 16 (the qr pipeline parallelizes work across
warps in the M dimension, so with M0 = 16 it can use only 1 warp).
3. Enable kMergeNumHeadGroupsSeqLenQ (an optimization that merges one
group of heads in GQA) for all hdim values, not only 128.
4. Increase the number of splits (multiply it by the number of head
groups) if (3) is used.
5. Increase the number of splits on RDNA GPUs (`multiProcessorCount` is
the number of WGPs on RDNA, not CUs, so it should be doubled to have a
meaning similar to the one it has on CDNA).
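
Items 1, 4 and 5 can be sketched as plain host-side helpers. This is only an illustration of the rules as stated in the list: the function names, parameters and the `is_rdna` flag are hypothetical and are not taken from the CK Tile sources, and how the base split count and the number of head groups are obtained is left to the caller.

```cpp
#include <cassert>

// Hypothetical helpers for illustration; none of these names are CK Tile APIs.

// Item 1: use the smaller M0 tile when the query sequence is short.
inline int choose_m0(int seqlen_q)
{
    assert(seqlen_q > 0);
    return seqlen_q <= 16 ? 16 : 64;
}

// Item 5: on RDNA, multiProcessorCount counts WGPs rather than CUs,
// so it is doubled to get a CU-like count comparable to CDNA.
inline int cu_like_count(int multi_processor_count, bool is_rdna)
{
    assert(multi_processor_count > 0);
    return is_rdna ? 2 * multi_processor_count : multi_processor_count;
}

// Item 4: when head groups are merged (item 3), the split count is
// multiplied by the number of head groups (passed in by the caller).
inline int scale_num_splits(int base_num_splits, int num_head_groups, bool merge_head_groups)
{
    assert(base_num_splits > 0 && num_head_groups > 0);
    return merge_head_groups ? base_num_splits * num_head_groups : base_num_splits;
}
```

For example, under these assumptions a decoding call with `seqlen_q = 1` selects `M0 = 16`, and a base split count of 4 with 2 merged head groups becomes 8.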

Performance on gfx1151:

| Case | develop (GB/s) | This PR (GB/s) |
|:-------|-------:|-------:|
| [fp16\|group\|bshd] b:1, h:32/32, s:1/45056, d:64/64 | 127.58 | 183.11 |
| [fp16\|group\|bhsd] b:1, h:32/32, s:1/45056, d:64/64 | 153.64 | 215.02 |
| [fp16\|group\|bshd] b:1, h:16/8, s:1/77184, d:128/128 | 120.51 | 225.76 |
| [fp16\|group\|bhsd] b:1, h:16/8, s:1/77184, d:128/128 | 130.62 | 223.84 |
| [fp16\|group\|bshd] b:1, h:32/32, s:1/9600, d:128/128 | 82.65 | 138.44 |
| [fp16\|group\|bhsd] b:1, h:32/32, s:1/9600, d:128/128 | 105.75 | 220.45 |
| [fp16\|group\|bshd] b:1, h:8/1, s:1/401024, d:256/256 | 16.27 | 187.89 |
| [fp16\|group\|bhsd] b:1, h:8/1, s:1/401024, d:256/256 | 16.28 | 188.19 |
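
To read the table as speedups rather than absolute throughput, the ratio "This PR / develop" can be computed per row. The sketch below only restates the numbers from the table; the `Row` type and function names are made up for this illustration. The resulting ratios range from roughly 1.4x (the d:64 cases) to roughly 11.6x (the h:8/1, d:256 cases).

```cpp
#include <cassert>

// Throughput in GB/s, copied from the table above (develop vs. this PR).
struct Row
{
    double develop_gbps;
    double pr_gbps;
};

const Row kRows[] = {
    {127.58, 183.11}, // bshd, h:32/32, s:1/45056,  d:64/64
    {153.64, 215.02}, // bhsd, h:32/32, s:1/45056,  d:64/64
    {120.51, 225.76}, // bshd, h:16/8,  s:1/77184,  d:128/128
    {130.62, 223.84}, // bhsd, h:16/8,  s:1/77184,  d:128/128
    {82.65, 138.44},  // bshd, h:32/32, s:1/9600,   d:128/128
    {105.75, 220.45}, // bhsd, h:32/32, s:1/9600,   d:128/128
    {16.27, 187.89},  // bshd, h:8/1,   s:1/401024, d:256/256
    {16.28, 188.19},  // bhsd, h:8/1,   s:1/401024, d:256/256
};

// Speedup of this PR over develop for one row.
inline double speedup(const Row& r)
{
    assert(r.develop_gbps > 0.0);
    return r.pr_gbps / r.develop_gbps;
}

// Smallest speedup across all rows.
inline double min_speedup()
{
    double best = speedup(kRows[0]);
    for(const Row& r : kRows)
        if(speedup(r) < best)
            best = speedup(r);
    return best;
}

// Largest speedup across all rows.
inline double max_speedup()
{
    double best = speedup(kRows[0]);
    for(const Row& r : kRows)
        if(speedup(r) > best)
            best = speedup(r);
    return best;
}
```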

## Test Plan

An additional test case is added to the existing test. It uses seqlen_q =
1, GQA, and no mask to trigger the changes:
```
ninja test_ck_tile_fmha_fwd_fp16 && bin/test_ck_tile_fmha_fwd_fp16 --gtest_filter="*SplitKV*"
ninja test_ck_tile_fmha_fwd_bf16 && bin/test_ck_tile_fmha_fwd_bf16 --gtest_filter="*SplitKV*"
```

Manual testing can be done with these commands:
```
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=64  -s=1 -s_k=$((352 * 128))  -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=16 -h_k=8  -d=128 -s=1 -s_k=$((603 * 128))  -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=128 -s=1 -s_k=$((75 * 128))   -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=8  -h_k=1  -d=256 -s=1 -s_k=$((3133 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
```
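
Each `-s_k` value above is written as a multiple of `-page_block_size=128`, and the products match the key sequence lengths in the performance table. A small sketch of that arithmetic follows; reading the first factor as a page count is an interpretation, and `seqlen_k` is a hypothetical helper, not part of the example binary.

```cpp
#include <cassert>

// Hypothetical helper: key sequence length as (number of pages) * (page block size).
// The commands above use page_block_size = 128:
//   352 * 128  = 45056
//   603 * 128  = 77184
//   75 * 128   = 9600
//   3133 * 128 = 401024
inline long seqlen_k(long num_pages, long page_block_size = 128)
{
    assert(num_pages > 0 && page_block_size > 0);
    return num_pages * page_block_size;
}
```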

## Test Result

All the tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-03 06:16:10 +00:00

CK Tile Example Suite

This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
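
The tiling idea can be illustrated with a plain host-side C++ sketch that uses no CK Tile API at all. It computes C = A * B in fixed-size tiles: each step copies one tile of A and B into small staging buffers (standing in for LDS), accumulates into a local array (standing in for registers), and writes the finished tile back. All names and the tile size are chosen for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kTile = 4; // illustrative tile edge, not a CK Tile setting

// C (m x n) = A (m x k) * B (k x n), all row-major, processed tile by tile.
void tiled_gemm(const std::vector<float>& a, const std::vector<float>& b,
                std::vector<float>& c, int m, int n, int k)
{
    assert(a.size() == static_cast<std::size_t>(m) * k);
    assert(b.size() == static_cast<std::size_t>(k) * n);
    assert(c.size() == static_cast<std::size_t>(m) * n);

    float a_tile[kTile][kTile]; // staging buffer for a tile of A
    float b_tile[kTile][kTile]; // staging buffer for a tile of B

    for(int i0 = 0; i0 < m; i0 += kTile)
    {
        for(int j0 = 0; j0 < n; j0 += kTile)
        {
            float acc[kTile][kTile] = {}; // per-tile accumulators
            for(int k0 = 0; k0 < k; k0 += kTile)
            {
                // Data movement: load one tile of A and B, zero-padding past the edges.
                for(int r = 0; r < kTile; ++r)
                    for(int s = 0; s < kTile; ++s)
                    {
                        const bool in_a = (i0 + r < m) && (k0 + s < k);
                        const bool in_b = (k0 + r < k) && (j0 + s < n);
                        a_tile[r][s] = in_a ? a[(i0 + r) * k + (k0 + s)] : 0.0f;
                        b_tile[r][s] = in_b ? b[(k0 + r) * n + (j0 + s)] : 0.0f;
                    }
                // Computation: multiply the two staged tiles.
                for(int i = 0; i < kTile; ++i)
                    for(int j = 0; j < kTile; ++j)
                        for(int kk = 0; kk < kTile; ++kk)
                            acc[i][j] += a_tile[i][kk] * b_tile[kk][j];
            }
            // Epilogue: store the valid part of the accumulated tile.
            for(int i = 0; i < kTile && i0 + i < m; ++i)
                for(int j = 0; j < kTile && j0 + j < n; ++j)
                    c[(i0 + i) * n + (j0 + j)] = acc[i][j];
        }
    }
}
```

On a GPU the three phases would be distributed across threads and overlapped; the sketch only shows the order of the steps for a single tile at a time.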


Example Index

| Example | Operation | Description |
|:--------|:----------|:------------|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |

Technical Highlights


How to Build & Run

mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j

Each example produces its own executable in build/bin/.


Learning and Extending


References

