Linjun-AMD d22aafb48b [rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d)
[CK][fmha] Add StreamLLM sink support to batch_prefill pipeline (#6479)

## Motivation

The existing paged-KV attention pipelines (pagedkv, splitkv) support StreamLLM-style sink tokens: a fixed set of initial tokens that stays in attention alongside the sliding window. The `batch_prefill` pipeline (chunked prefill with vLLM-style block tables) previously hardcoded `kHasSink = false`. That made it incompatible with the sink-based attention patterns used in LLM serving.

This PR extends `batch_prefill` to support `kHasSink` and wires it into `fmha_fwd_runner` so it can be validated against the existing CPU reference.

## Technical Details

**Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink` is set, the K/V loop splits into a sink phase [0, sink_seq_end) and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv.
  - At the sink→window transition, K advances by `seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition. This fixes a window-mismatch bug: after the large K jump, V lagged kN0 behind K and loaded from the wrong sequence positions.
- The bias window, dropout `seq_offset`, and mask type (`LogitsSinkMask`) are updated to be sink-aware.
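
The following host-side C++ sketch models the two-phase schedule above. It is not CK Tile code. `sink_window_schedule` and `TileVisit` are hypothetical names. The sketch also assumes that `sink_seq_end` and `seqlen_k_start` are already multiples of `kN0`. Under that assumption, the last sink tile starts at `sink_seq_end - kN0`, so advancing by `seqlen_k_start - sink_seq_end + kN0` lands exactly on `seqlen_k_start`. At the transition, V is re-derived from K's position instead of being stepped independently.

```cpp
#include <cassert>
#include <vector>

// One K/V tile visit: the sequence row where each window starts.
struct TileVisit {
    int k_start;
    int v_start;
};

// Hypothetical model of a sink-aware K/V loop (not the CK Tile pipeline).
// Assumes 0 <= sink_seq_end <= seqlen_k_start <= seqlen_k_end and that
// sink_seq_end and seqlen_k_start are multiples of kN0.
std::vector<TileVisit> sink_window_schedule(int sink_seq_end, int seqlen_k_start,
                                            int seqlen_k_end, int kN0)
{
    assert(sink_seq_end % kN0 == 0 && seqlen_k_start % kN0 == 0);
    assert(0 <= sink_seq_end && sink_seq_end <= seqlen_k_start &&
           seqlen_k_start <= seqlen_k_end);

    std::vector<TileVisit> visits;
    const int num_sink_tiles = sink_seq_end / kN0;
    int k_pos = num_sink_tiles > 0 ? 0 : seqlen_k_start; // K window origin
    int v_pos = k_pos;                                   // V window origin

    // Sink phase: tiles covering [0, sink_seq_end).
    for (int i = 0; i < num_sink_tiles; ++i) {
        visits.push_back({k_pos, v_pos});
        const bool last_sink_tile = (i == num_sink_tiles - 1);
        // Normal step is kN0; after the last sink tile, K also skips the gap.
        k_pos += last_sink_tile ? (seqlen_k_start - sink_seq_end + kN0) : kN0;
        // Re-initialize V from K's position at the transition so that V
        // does not stay behind K after the jump.
        v_pos = last_sink_tile ? k_pos : v_pos + kN0;
    }

    // Window phase: tiles covering [seqlen_k_start, seqlen_k_end).
    for (; k_pos < seqlen_k_end; k_pos += kN0, v_pos += kN0)
        visits.push_back({k_pos, v_pos});
    return visits;
}
```

For example, `sink_window_schedule(128, 512, 1024, 128)` visits K/V tiles starting at rows 0, 512, 640, 768, and 896, with K and V aligned on every tile.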

**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`, `fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (previously hardcoded to `false`).
- Codegen adds an `F_sink` field and skips batch-mode kernels, because batch_prefill requires group mode.
- The CMake test filter is broadened from 9 → 33 instances, covering fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.
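
One way to picture the codegen change is as a filter over a cross product of feature axes. The sketch below is written in C++ to match the rest of this page, although the real generator is `fmha_batch_prefill.py`. The struct and function names are hypothetical. The sketch models only the four axes listed above plus the batch/group mode switch, so it does not reproduce the 33-instance count of the CMake filter.

```cpp
#include <cassert>
#include <initializer_list>
#include <string>
#include <vector>

// Hypothetical description of one generated batch_prefill instance.
struct PrefillInstance {
    std::string dtype; // "fp16" or "bf16"
    bool mask;         // masked vs. no-mask
    bool lse;          // write log-sum-exp output
    bool sink;         // F_sink: sink tokens enabled
    bool group_mode;   // batch_prefill only supports group mode
};

// Enumerate dtype x mask x lse x sink x mode, keeping group-mode kernels only.
std::vector<PrefillInstance> enumerate_batch_prefill_instances()
{
    std::vector<PrefillInstance> out;
    for (const char* dtype : {"fp16", "bf16"})
        for (bool mask : {false, true})
            for (bool lse : {false, true})
                for (bool sink : {false, true})
                    for (bool group_mode : {false, true}) {
                        if (!group_mode)
                            continue; // batch-mode kernels are skipped
                        out.push_back({dtype, mask, lse, sink, group_mode});
                    }
    return out;
}
```

In this sketch the mode check is a simple `continue`. That keeps the rule "group mode required" in one place, separate from the feature axes.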

**Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
- `run_fwd` dispatches to `fmha_batch_prefill()` when all three conditions hold: group mode, paged KV, and num_splits == 1.
- K/V strides are corrected for the runner's [num_pages, nhead_k, page_block_size, hdim] layout.
- The `page_block_size % 128` check is relaxed, because batch_prefill supports page_block_size=16.
- The CPU reference's paged-KV reordering guards are extended with `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.
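
The sketch below shows the addressing implied by the runner's [num_pages, nhead_k, page_block_size, hdim] layout and the dispatch condition for `fmha_batch_prefill()`. All names here are hypothetical, and offsets are in elements. The sketch assumes that a per-sequence block table maps a logical page index to a physical page, as in vLLM-style block tables.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Element strides for a paged K/V cache laid out as
// [num_pages, nhead_k, page_block_size, hdim].
struct PagedKvStrides {
    int64_t page; // distance between consecutive pages
    int64_t head; // distance between KV heads within a page
    int64_t row;  // distance between tokens within a page
};

PagedKvStrides make_paged_kv_strides(int64_t nhead_k, int64_t page_block_size, int64_t hdim)
{
    return {nhead_k * page_block_size * hdim, page_block_size * hdim, hdim};
}

// Hypothetical helper: offset of element [token][head][d] of one sequence,
// resolving the physical page through the block table.
int64_t paged_kv_offset(const std::vector<int>& block_table, const PagedKvStrides& s,
                        int64_t page_block_size, int64_t token, int64_t head, int64_t d)
{
    assert(page_block_size > 0 && token >= 0);
    const int64_t page = block_table[token / page_block_size]; // physical page
    const int64_t row = token % page_block_size;               // slot within page
    return page * s.page + head * s.head + row * s.row + d;
}

// Hypothetical predicate mirroring the run_fwd dispatch rule described above.
bool use_batch_prefill(bool group_mode, bool paged_kv, int num_splits)
{
    return group_mode && paged_kv && num_splits == 1;
}
```

With nhead_k=2, page_block_size=16, and hdim=128, the strides are 4096, 2048, and 128. For a block table {3, 0}, token 17 of head 1 at d=5 lives on physical page 0 at offset 2181.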

## Test Plan

Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, then run `tile_example_fmha_fwd` in group mode with page_block_size=16.

Test matrix:
- Mask: no-mask, causal, sliding window
- Sink: nsink, sink=1..128
- dtype: fp16, bf16
- LSE output: on/off
- seqlen ∈ {512, 1024, 2048, 4096} × window ∈ {32, 256, 512, 1024}
- GQA, chunked prefill, large batch×seqlen
- page_block_size: 16, 32
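
The sketch below gives one plausible visibility rule for reasoning about the sliding-window + sink cases in this matrix. It is an assumption, not the CK CPU reference or `LogitsSinkMask`. It assumes seqlen_q == seqlen_k, a causal mask, and a left window that counts the query's own position. It also assumes `sink` initial keys that every query may attend to. The function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical rule for causal sliding-window attention with sinks:
// query i may attend to key j if j is not in the future and j is either one
// of the first `sink` keys or one of the last `window` keys up to i.
bool key_visible(int i, int j, int window, int sink)
{
    if (j > i)
        return false;      // causal: no future keys
    if (j < sink)
        return true;       // sink tokens stay visible to every query
    return i - j < window; // inside the sliding window
}

// Number of visible keys for each query row of a seqlen x seqlen problem.
std::vector<int> visible_counts(int seqlen, int window, int sink)
{
    assert(seqlen >= 0 && window >= 1 && sink >= 0);
    std::vector<int> counts(seqlen, 0);
    for (int i = 0; i < seqlen; ++i)
        for (int j = 0; j < seqlen; ++j)
            if (key_visible(i, j, window, sink))
                ++counts[i];
    return counts;
}
```

With seqlen 8, window 2, and sink 1, the per-row counts are 1, 2, 3, 3, 3, 3, 3, 3. After the first two rows, each query sees the sink key plus two window keys.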

## Test Result

171 test cases, all reporting `valid:y`:
- nmask + nsink: ✓
- causal + nsink: ✓
- causal + sink=8: ✓
- sliding window + sink=8 (d=128, d=256): ✓
- bf16, LSE output, GQA: ✓

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-21 11:05:12 +00:00

CK Tile Example Suite

This directory contains a suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example implements a key deep learning or HPC operation using tile-based parallelism, modular pipelines, and data-movement policies.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as compositions of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, which lets kernels be both efficient and flexible.
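
The plain C++ sketch below is a rough analogy only. It uses no CK Tile headers or APIs, and `tiled_gemm` and the tile sizes `TM`, `TN`, `TK` are hypothetical. It shows the same decomposition on a CPU. Each output tile is accumulated from staged copies of A and B tiles, mirroring the global -> LDS -> register movement that a CK Tile pipeline schedules on the GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C[M x N] = A[M x K] * B[K x N], all row-major, computed tile by tile.
void tiled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, int M, int N, int K, int TM, int TN, int TK)
{
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    assert(C.size() == static_cast<std::size_t>(M) * N);
    std::vector<float> a_tile(TM * TK), b_tile(TK * TN); // "LDS" staging buffers
    for (int m0 = 0; m0 < M; m0 += TM)
        for (int n0 = 0; n0 < N; n0 += TN) {
            std::vector<float> acc(TM * TN, 0.0f); // per-tile "register" accumulator
            for (int k0 = 0; k0 < K; k0 += TK) {
                // Stage one A tile and one B tile, zero-padding past the edges.
                for (int i = 0; i < TM; ++i)
                    for (int k = 0; k < TK; ++k)
                        a_tile[i * TK + k] = (m0 + i < M && k0 + k < K)
                                                 ? A[(m0 + i) * K + k0 + k] : 0.0f;
                for (int k = 0; k < TK; ++k)
                    for (int j = 0; j < TN; ++j)
                        b_tile[k * TN + j] = (k0 + k < K && n0 + j < N)
                                                 ? B[(k0 + k) * N + n0 + j] : 0.0f;
                // Compute on the staged tiles only.
                for (int i = 0; i < TM; ++i)
                    for (int k = 0; k < TK; ++k)
                        for (int j = 0; j < TN; ++j)
                            acc[i * TN + j] += a_tile[i * TK + k] * b_tile[k * TN + j];
            }
            // Write back the part of the tile that lies inside C.
            for (int i = 0; i < TM && m0 + i < M; ++i)
                for (int j = 0; j < TN && n0 + j < N; ++j)
                    C[(m0 + i) * N + n0 + j] = acc[i * TN + j];
        }
}
```

Here, staging and computation are separate loops over the same tile. On a GPU, a pipeline can overlap these two phases, which this serial sketch does not attempt.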


Example Index

| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |


How to Build & Run

```bash
mkdir build && cd build
# <arch> is the GPU target, e.g. gfx90a or gfx942
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```

Each example produces its own executable in build/bin/.



