mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes its KV-range slice and writes a normalized (o_acc, lse)
partial to FP32 workspaces, which the caller reduces into the final
output.
Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
o_acc. The masked-empty early-exit returns lse = -inf so downstream
combine weighs the partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
rowmax: lse = (scale_s / log2(e)) * m + log(l). Previously we used
m / log2(e) + log(l), which dropped the per-head scale and produced
LSE values ~1/scale too large.
- Fix post-process parity: which SP register is left in the
alu0-done/alu1-pending state at loop exit depends on the parity of
the *iteration count* (= num_total_loop - num_blocks_start), not on
num_total_loop alone. For non-split (num_blocks_start == 0) the two
parities coincide; for splits starting at an odd tile they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
kPageBlockSize-sized tiles, but block_tables is indexed in
page_size-sized pages — shifting block_table_offset by num_blocks_start
reads the wrong pages whenever kPageBlockSize != page_size. Replaced
with split_token_offset = num_blocks_start * kPageBlockSize added to
logical_token before /page_size, so the page lookup uses the absolute
token position.
Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
(defaults to 1, so non-split callers see dim3(..., 1, 1)).
- New write path: when num_splits > 1, the kernel skips the user
epilogue and instead writes the FP32 (o_acc, lse) tile pair into
workspace tensors at [head, split, batch_start_token, ...] using
Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
lse. Host strides come from kargs.
Co-authored-by: Cursor <cursoragent@cursor.com>
CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles"—rectangular blocks of computation and data movement. The pipeline & policy orchestrates data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See
include/ck_tile/tile_program/tile_distribution/for mapping tiles to thread blocks. - Block Tile Pipelines: See
include/ck_tile/tile_program/block_tile_pipeline/for memory/computation pipelines. - Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.