mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
GetWorkspaceDeviceSizeUpperBound was computing
max_batch * nhead_q * max_seqlen_q * hdim_q
in non-deterministic group mode, but PrepareWorkspaceHost actually returns
nhead_q * seqstart_q[batch] * hdim_q
i.e. it scales with the sum of *padded* per-batch seqlen_q, not max_batch
times the *logical* max. When per-batch padding makes seqstart_q[batch]
exceed max_batch * max_seqlen_q, the launcher under-allocates dq_acc, the
kernel writes past the buffer, and tests see either ~42% wrong QGrad
values or a GPU page fault (e.g. test_ck_tile_fmha_bwd_bf16
QKVPadding/23,24,26 corrupt; /27 page-faults).
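To make the under-allocation concrete, here is a small C++ sketch with made-up sizes: two batches whose logical q lengths 100 and 128 are padded to 128 and 256, with nhead_q = 8 and hdim_q = 128. The helper names are hypothetical and are not CK functions. They only evaluate the old bound and the padded-total requirement described above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: exclusive prefix sum of per-batch *padded* q lengths.
// seqstart[b] is the first token of batch b; seqstart.back() is the padded total.
std::vector<std::int64_t> make_seqstart(const std::vector<std::int64_t>& padded_seqlen_q)
{
    std::vector<std::int64_t> seqstart(padded_seqlen_q.size() + 1, 0);
    for(std::size_t b = 0; b < padded_seqlen_q.size(); ++b)
        seqstart[b + 1] = seqstart[b] + padded_seqlen_q[b];
    return seqstart;
}

// Old bound in non-deterministic group mode: scales with the *logical* max.
std::int64_t old_upper_bound(std::int64_t max_batch, std::int64_t nhead_q,
                             std::int64_t max_seqlen_q, std::int64_t hdim_q)
{
    return max_batch * nhead_q * max_seqlen_q * hdim_q;
}

// Size the host-side preparation actually uses: scales with the padded total.
std::int64_t actual_requirement(std::int64_t nhead_q,
                                const std::vector<std::int64_t>& seqstart_q,
                                std::int64_t hdim_q)
{
    assert(!seqstart_q.empty());
    return nhead_q * seqstart_q.back() * hdim_q;
}
```

With these numbers the old bound is 2 * 8 * 128 * 128 = 262144, while the padded layout needs 8 * 384 * 128 = 393216, so a buffer sized by the old bound is too small.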
Fix: replace the (max_batch, max_seqlen_q) pair with a single
total_seqlen_q_padded parameter holding the true total padded q tokens.
Launcher derives it from the trait (group: t.seqlen_q already is the
padded total; batch: t.batch * t.seqlen_q). The four mode formulas
collapse to one:
size = nhead_q * nsplits_factor * total_seqlen_q_padded * hdim_q
where nsplits_factor is 1 for non-deterministic, ceil(max_seqlen_k / kN0)
for deterministic group, and the persistent worker computation for
deterministic non-group (the only branch that still needs max_batch).
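The collapsed formula can be sketched as follows. BwdTraitsSketch, BwdMode and the helper functions are hypothetical stand-ins for the launcher's trait and mode logic, not CK's actual types. The persistent-worker split count for deterministic non-group mode is not spelled out here, so the sketch takes it as an input rather than guessing how it is computed.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the trait fields mentioned above.
struct BwdTraitsSketch
{
    bool is_group_mode;
    std::int64_t batch;
    std::int64_t seqlen_q; // group mode: already the padded total; batch mode: per-batch length
};

// total_seqlen_q_padded as the launcher derives it from the trait.
std::int64_t total_seqlen_q_padded(const BwdTraitsSketch& t)
{
    return t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q;
}

std::int64_t ceil_div(std::int64_t a, std::int64_t b)
{
    assert(b > 0);
    return (a + b - 1) / b;
}

enum class BwdMode
{
    NonDeterministic,
    DeterministicGroup,
    DeterministicNonGroup
};

// persistent_nsplits is passed in because its computation is not described here.
std::int64_t nsplits_factor(BwdMode mode, std::int64_t max_seqlen_k, std::int64_t kN0,
                            std::int64_t persistent_nsplits)
{
    if(mode == BwdMode::NonDeterministic)
        return 1;
    if(mode == BwdMode::DeterministicGroup)
        return ceil_div(max_seqlen_k, kN0);
    return persistent_nsplits; // DeterministicNonGroup
}

// size = nhead_q * nsplits_factor * total_seqlen_q_padded * hdim_q
std::int64_t dq_acc_workspace_size(std::int64_t nhead_q, std::int64_t nsplits,
                                   std::int64_t total_q, std::int64_t hdim_q)
{
    return nhead_q * nsplits * total_q * hdim_q;
}
```

For example, a non-deterministic group-mode run with a padded total of 384 q tokens, nhead_q = 8 and hdim_q = 128 gives 8 * 1 * 384 * 128 = 393216. Only the split factor changes between modes, which is why a single formula suffices.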
No caller-side API change: FA, AITER and the CK runner already pass
q.shape[0] (the padded total) as traits.seqlen_q in group mode.
Verified on gfx1201: full test_ck_tile_fmha_bwd_{bf16,fp16} 672/672 PASS,
0 fail, 0 crash (was 27/28 QKVPadding fails + 1 GPU illegal access).
CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Its pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, which makes kernels both efficient and flexible.
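As a CPU-side analogy only (plain C++, not the CK Tile API), the sketch below computes a row-major GEMM one output tile at a time. Each output tile stands in for one workgroup, the small a_tile/b_tile buffers stand in for LDS, and acc stands in for registers. The function name and tile sizes are hypothetical; making the tile sizes template parameters loosely mirrors how tile shapes are fixed at compile time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU analogy of a tiled GEMM: C[M x N] = A[M x K] * B[K x N], all row-major.
template <std::size_t TileM, std::size_t TileN, std::size_t TileK>
void tiled_gemm(const std::vector<float>& a, const std::vector<float>& b, std::vector<float>& c,
                std::size_t M, std::size_t N, std::size_t K)
{
    assert(a.size() == M * K && b.size() == K * N && c.size() == M * N);
    for(std::size_t m0 = 0; m0 < M; m0 += TileM)
    {
        for(std::size_t n0 = 0; n0 < N; n0 += TileN)
        {
            float acc[TileM][TileN] = {}; // per-tile accumulator ("registers")
            for(std::size_t k0 = 0; k0 < K; k0 += TileK)
            {
                // Stage the A and B sub-tiles into local buffers ("LDS");
                // out-of-range entries stay zero so edge tiles are handled.
                float a_tile[TileM][TileK] = {};
                float b_tile[TileK][TileN] = {};
                for(std::size_t i = 0; i < TileM && m0 + i < M; ++i)
                    for(std::size_t k = 0; k < TileK && k0 + k < K; ++k)
                        a_tile[i][k] = a[(m0 + i) * K + (k0 + k)];
                for(std::size_t k = 0; k < TileK && k0 + k < K; ++k)
                    for(std::size_t j = 0; j < TileN && n0 + j < N; ++j)
                        b_tile[k][j] = b[(k0 + k) * N + (n0 + j)];
                // Compute on the staged tiles only.
                for(std::size_t i = 0; i < TileM; ++i)
                    for(std::size_t j = 0; j < TileN; ++j)
                        for(std::size_t k = 0; k < TileK; ++k)
                            acc[i][j] += a_tile[i][k] * b_tile[k][j];
            }
            // Write the finished tile back ("global memory"), skipping padding.
            for(std::size_t i = 0; i < TileM && m0 + i < M; ++i)
                for(std::size_t j = 0; j < TileN && n0 + j < N; ++j)
                    c[(m0 + i) * N + (n0 + j)] = acc[i][j];
        }
    }
}
```

Calling tiled_gemm<2, 2, 2> on a 2x3 by 3x2 product exercises a partial K tile, since K = 3 is not a multiple of TileK = 2.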
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
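At the grid level, the simplest form of mapping tiles to thread blocks is a row-major assignment of output tiles to workgroup ids, sketched below in plain C++. The names are hypothetical, and this is not the tile_distribution machinery in CK Tile; it only illustrates how a linear block id can be turned into tile coordinates.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical types and helpers; a generic row-major mapping for illustration.
struct TileCoord
{
    std::int64_t tile_m, tile_n; // tile indices in the output grid
    std::int64_t row0, col0;     // first output element covered by the tile
};

constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Number of workgroups needed to cover an M x N output with TileM x TileN tiles.
constexpr std::int64_t num_tiles(std::int64_t M, std::int64_t N,
                                 std::int64_t TileM, std::int64_t TileN)
{
    return ceil_div(M, TileM) * ceil_div(N, TileN);
}

// Map a linear workgroup id to its output tile (row-major over tiles).
TileCoord block_to_tile(std::int64_t block_id, std::int64_t M, std::int64_t N,
                        std::int64_t TileM, std::int64_t TileN)
{
    assert(block_id >= 0 && block_id < num_tiles(M, N, TileM, TileN));
    const std::int64_t tiles_n = ceil_div(N, TileN);
    const std::int64_t tm = block_id / tiles_n;
    const std::int64_t tn = block_id % tiles_n;
    return {tm, tn, tm * TileM, tn * TileN};
}
```

For a 300 x 200 output with 128 x 64 tiles, this launches 3 * 4 = 12 workgroups; block 11 covers the bottom-right tile starting at row 256, column 192. Edge tiles extend past the output and must be bounds-checked by the kernel.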
How to Build & Run
```bash
mkdir build && cd build
# replace <arch> with your GPU target, e.g. gfx942
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.