Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
## Motivation
The existing paged-KV attention pipelines (pagedkv, splitkv) support
StreamingLLM-style sink tokens: a fixed set of initial tokens kept in
attention alongside the sliding window. The `batch_prefill` pipeline
(chunked prefill with vLLM-style block tables) previously hardcoded
`kHasSink = false`, making it incompatible with sink-based attention
patterns in LLM serving scenarios.
This PR extends `batch_prefill` to support `kHasSink` and wires it
into `fmha_fwd_runner` for validation against the existing CPU
reference.
## Technical Details
**Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink` is set, the K/V loop splits into a sink phase `[0, sink_seq_end)` and a window phase `[seqlen_k_start, seqlen_k_end)`, mirroring pagedkv.
- At the sink→window transition, the K advance jumps by `seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition. This fixes a window mismatch bug: after the large jump, V lagged `kN0` behind K and loaded from the wrong sequence position.
- Bias window, dropout `seq_offset`, and mask type (`LogitsSinkMask`) are updated for sink-awareness.
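
The two-phase iteration and the transition jump can be modelled in a few lines of plain Python. This is an illustrative sketch, not the pipeline's code: the function name `k_tile_starts` is hypothetical, and it assumes that `sink_seq_end` and `seqlen_k_start` are multiples of `kN0` and that the window begins at or after the end of the sink region.

```python
def k_tile_starts(sink_seq_end, seqlen_k_start, seqlen_k_end, kN0):
    """Return the K start position visited on each loop iteration.

    Hypothetical model: the sink phase covers [0, sink_seq_end) and the
    window phase covers [seqlen_k_start, seqlen_k_end); both are walked
    in kN0-sized tiles.
    """
    # Assumptions of this sketch (not stated guarantees of the kernel).
    assert sink_seq_end % kN0 == 0 and seqlen_k_start % kN0 == 0
    assert sink_seq_end <= seqlen_k_start <= seqlen_k_end

    num_sink_tiles = sink_seq_end // kN0
    num_window_tiles = -(-(seqlen_k_end - seqlen_k_start) // kN0)  # ceil div

    pos = 0 if num_sink_tiles > 0 else seqlen_k_start
    starts = []
    for i in range(num_sink_tiles + num_window_tiles):
        starts.append(pos)
        if i + 1 == num_sink_tiles:
            # Last sink tile: bridge the gap to the window in one step.
            pos += seqlen_k_start - sink_seq_end + kN0
        else:
            pos += kN0
    return starts


print(k_tile_starts(sink_seq_end=128, seqlen_k_start=512,
                    seqlen_k_end=768, kN0=64))
```

In this model the last sink tile starts at `sink_seq_end - kN0`, so adding `seqlen_k_start - sink_seq_end + kN0` lands exactly on `seqlen_k_start`. Any V offset that tracks K has to make the same jump, which is why the description above re-initializes the V offsets at the transition.
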
**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`, `fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (previously hardcoded to `false`).
- Codegen adds an `F_sink` field and skips batch-mode kernels (group mode is required).
- CMake test filter broadened from 9 to 33 instances, covering fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.
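
As a rough picture of what such a generator filter does, the sketch below enumerates the four axes named above and drops batch-mode configurations. The axis tuples, the `enumerate_instances` function, and the instance naming scheme are all hypothetical; the sketch does not reproduce the real `fmha_batch_prefill.py` logic or its 33-instance count.

```python
from itertools import product

# Hypothetical axis names; the real generator's fields may differ.
MODES = ("batch", "group")
DTYPES = ("fp16", "bf16")
MASKS = ("mask", "nmask")
LSE = ("lse", "nlse")
SINK = ("sink", "nsink")


def enumerate_instances():
    """Enumerate kernel configurations, skipping batch-mode ones."""
    instances = []
    for mode, dtype, mask, lse, sink in product(MODES, DTYPES, MASKS, LSE, SINK):
        if mode != "group":
            continue  # batch_prefill requires group mode
        instances.append(f"{dtype}_{mode}_{mask}_{lse}_{sink}")
    return instances


instances = enumerate_instances()
print(len(instances))
```
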
**Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
- `fmha_batch_prefill()` is dispatched from `run_fwd` when group mode, paged KV, and `num_splits == 1` all hold.
- K/V strides corrected for the runner's `[num_pages, nhead_k, page_block_size, hdim]` layout.
- `page_block_size % 128` check relaxed: batch_prefill supports `page_block_size=16`.
- CPU reference paged-KV reordering guards extended with `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.
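
The dispatch condition and the stride arithmetic for this layout can be sketched as follows. All function and key names here are hypothetical, and the sketch assumes a contiguous row-major `[num_pages, nhead_k, page_block_size, hdim]` tensor addressed through a vLLM-style block table; it is not the runner's code.

```python
def use_batch_prefill(is_group_mode, use_paged_kv, num_splits):
    """Dispatch condition as described: group mode + paged KV + one split."""
    return is_group_mode and use_paged_kv and num_splits == 1


def paged_kv_strides(nhead_k, page_block_size, hdim):
    """Element strides of a contiguous [num_pages, nhead_k, page_block_size, hdim] tensor."""
    return {
        "page": nhead_k * page_block_size * hdim,
        "head": page_block_size * hdim,
        "token": hdim,
        "dim": 1,
    }


def kv_element_offset(block_table, token, head, dim,
                      nhead_k, page_block_size, hdim):
    """Flat offset of one K/V element, resolving the logical token
    position through the block table (logical block -> physical page)."""
    s = paged_kv_strides(nhead_k, page_block_size, hdim)
    page = block_table[token // page_block_size]
    slot = token % page_block_size
    return (page * s["page"] + head * s["head"]
            + slot * s["token"] + dim * s["dim"])


# Token 17 with page_block_size=16 lives in logical block 1, slot 1.
print(kv_element_offset([7, 2], token=17, head=1, dim=3,
                        nhead_k=2, page_block_size=16, hdim=128))
```

With this layout the per-token stride is `hdim` and the per-head stride is `page_block_size * hdim`, so a runner that assumed a different dimension order would read the wrong rows.
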
## Test Plan
Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run
`tile_example_fmha_fwd` in group mode with page_block_size=16.
Test matrix:
- Mask: no-mask, causal, sliding window
- Sink: nsink, sink=1..128
- dtype: fp16, bf16
- LSE output: on/off
- seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024}
- GQA, chunked prefill, large batch×seqlen
- page_block_size: 16, 32
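
To make the mask and sink axes of the matrix concrete, the sketch below lists which key positions a query may attend to under causal, sliding-window, and sink settings. The function `visible_keys` and its window convention (the `window` most recent keys, including the query's own position) are assumptions for illustration; the kernel's mask may define window bounds differently, and the sketch assumes query and key positions are aligned.

```python
def visible_keys(q_pos, window=None, sink=0):
    """Causal key positions visible to the query at q_pos.

    window=None means plain causal attention; otherwise only the
    `window` most recent keys (including q_pos) are kept, plus the
    first `sink` keys of the sequence.
    """
    lo = 0 if window is None else max(0, q_pos - window + 1)
    recent = range(lo, q_pos + 1)
    sinks = range(min(sink, q_pos + 1))  # sink tokens stay causal
    return sorted(set(sinks) | set(recent))


print(visible_keys(10, window=4, sink=2))  # window plus two sink tokens
print(visible_keys(5))                     # plain causal
```

When the window already reaches back to position 0, the sink tokens overlap it and add nothing, which is one of the edge cases a sink range such as 1..128 exercises on short sequences.
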
## Test Result
171 test cases, all reporting `valid:y`:
- nmask + nsink: ✓
- causal + nsink: ✓
- causal + sink=8: ✓
- sliding window + sink=8 (d=128, d=256): ✓
- bf16, LSE output, GQA: ✓
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.
## What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
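
The tile idea itself can be illustrated without any GPU code. The plain-Python sketch below computes a matrix product one rectangular output tile at a time while streaming the reduction dimension in tiles. It is only an analogy: `tiled_matmul` and its tile-size parameters are hypothetical and are not part of the CK Tile API, where tiles are mapped to thread blocks and staged through LDS and registers.

```python
def tiled_matmul(a, b, tile_m=2, tile_n=2, tile_k=2):
    """Multiply a (m x k) by b (k x n) tile by tile."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for m0 in range(0, m, tile_m):          # one output tile per (m0, n0)
        for n0 in range(0, n, tile_n):
            for k0 in range(0, k, tile_k):  # stream the K dimension in tiles
                for i in range(m0, min(m0 + tile_m, m)):
                    for j in range(n0, min(n0 + tile_n, n)):
                        acc = c[i][j]
                        for p in range(k0, min(k0 + tile_k, k)):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(tiled_matmul(a, b))
```
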
## Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
## Technical Highlights
- Tile Distribution: See `include/ck_tile/tile_program/tile_distribution/` for mapping tiles to thread blocks.
- Block Tile Pipelines: See `include/ck_tile/tile_program/block_tile_pipeline/` for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
## How to Build & Run

    mkdir build && cd build
    sh ../script/cmake-ck-dev.sh ../ <arch>
    make -j

Each example produces its own executable in `build/bin/`.
## Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.