mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
## Motivation
The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).
Two distinct failure modes were addressed:
1. **>4GB SRD overflow (`page_size < kN0`):** The SRD
`buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small
page sizes the rebased SRD spans the full KV pool and the offset wraps
past 2 GB, corrupting K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.
## Technical Details
**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**
| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load
(`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) +
`load_tile_flat` (V) | 64-bit |
**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.
**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.
**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.
**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.
**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.
## Test Plan
Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:
- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.
## Test Result
| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2
GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |
**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads —
removing the setters re-introduces page-faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles"—rectangular blocks of computation and data movement. The pipeline & policy orchestrates data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See
include/ck_tile/tile_program/tile_distribution/for mapping tiles to thread blocks. - Block Tile Pipelines: See
include/ck_tile/tile_program/block_tile_pipeline/for memory/computation pipelines. - Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.