Jeff Huang fdf4bb7fcc [rocm-libraries] ROCm/rocm-libraries#6653 (commit 1df887e)
[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653)

## Motivation

The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).

Two distinct failure modes were addressed:

1. **>4GB SRD overflow (`page_size < kN0`):** The SRD
`buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small
page sizes the rebased SRD spans the full KV pool and the offset wraps
past 2 GB, corrupting K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.
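
Failure mode 1 comes down to integer width. The host-side C++ sketch below assumes a page's byte offset is simply `page_id * page_stride_elems * elem_bytes` (hypothetical names and page geometry, not the kernel's actual addressing) and shows what a 32-bit register keeps once that offset outgrows it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Illustrative only: byte offset of the start of a KV page inside a flat pool.
// page_stride_elems and elem_bytes are hypothetical parameters.
constexpr std::uint64_t page_byte_offset(std::uint64_t page_id,
                                         std::uint64_t page_stride_elems,
                                         std::uint64_t elem_bytes) {
    return page_id * page_stride_elems * elem_bytes;
}

// What a 32-bit offset register would hold: the low 32 bits (value mod 2^32).
constexpr std::uint32_t as_u32_offset(std::uint64_t byte_offset) {
    return static_cast<std::uint32_t>(byte_offset);
}

// True when the offset is representable as a non-negative INT32 byte offset.
constexpr bool fits_in_int32(std::uint64_t byte_offset) {
    return byte_offset <=
           static_cast<std::uint64_t>(std::numeric_limits<std::int32_t>::max());
}
```

With bf16 (2 bytes) and a hypothetical page stride of 2^20 elements, page 3000 starts at 6,291,456,000 bytes, but its low 32 bits are 1,996,488,704, so a 32-bit offset would point roughly 1.86 GiB into the pool instead of roughly 5.86 GiB.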

## Technical Details

**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**

| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load (`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) + `load_tile_flat` (V) | 64-bit |

**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.

**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.
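
A simplified host-side model of why the trim matters: it treats an SRD as nothing more than a base address plus a `num_records` byte count and checks whether the whole described range stays inside the live page, which is the property the text says CDNA4 validates. `SimpleSrd`, `rebase`, and `range_within_page` are hypothetical; real SRDs carry more fields, and this does not model the hardware's page-table walk.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, heavily simplified buffer resource: base address + byte range.
struct SimpleSrd {
    std::uint64_t base;         // byte address the descriptor starts at
    std::uint64_t num_records;  // bytes covered by the descriptor
};

// Rebase the descriptor onto a page; optionally trim its range to the page.
constexpr SimpleSrd rebase(SimpleSrd srd, std::uint64_t page_addr,
                           std::uint64_t page_stride_bytes, bool trim) {
    srd.base = page_addr;
    if (trim) {
        // Plays the role of set_bottom_tensor_view_buffer_size(page_stride).
        srd.num_records = page_stride_bytes;
    }
    return srd;
}

// Does [base, base + num_records) stay within [page_addr, page_addr + page_bytes)?
constexpr bool range_within_page(SimpleSrd srd, std::uint64_t page_addr,
                                 std::uint64_t page_bytes) {
    return srd.base >= page_addr &&
           srd.base + srd.num_records <= page_addr + page_bytes;
}
```

For instance, rebasing a descriptor that originally covered a 64 MiB pool onto a 1 MiB page without trimming leaves 63 MiB of the described range past the page; with trimming, the range ends exactly at the page boundary.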

**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.
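
The double-buffering of `v_physical_pages_current` follows the usual ping-pong pattern: iteration `i` reads slot `i % 2` while the page indices for iteration `i + 1` are written into the other slot, so producer and consumer never touch the same storage. The sketch below shows only that indexing pattern in plain C++; `PingPongPages`, `walk_tiles`, and the page-lookup callback are hypothetical, and GPU asynchrony is not modeled.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical ping-pong holder for the physical page indices of one tile.
template <std::size_t kPagesPerTile>
struct PingPongPages {
    std::array<std::array<std::int32_t, kPagesPerTile>, 2> slots{};

    std::array<std::int32_t, kPagesPerTile>& for_iter(std::size_t iter) {
        return slots[iter % 2];
    }
};

// Walks num_iters tiles. Pages for iteration i+1 are computed into the other
// slot before iteration i's pages are consumed, so the "next" write never
// overwrites pages that are still in use.
template <std::size_t kPagesPerTile>
std::vector<std::int32_t> walk_tiles(
    std::size_t num_iters,
    const std::function<std::int32_t(std::size_t, std::size_t)>& lookup_page) {
    PingPongPages<kPagesPerTile> pages;
    std::vector<std::int32_t> consumed;
    for (std::size_t p = 0; p < kPagesPerTile; ++p) {
        pages.for_iter(0)[p] = lookup_page(0, p);
    }
    for (std::size_t i = 0; i < num_iters; ++i) {
        if (i + 1 < num_iters) {  // prefetch the next iteration into the other slot
            for (std::size_t p = 0; p < kPagesPerTile; ++p) {
                pages.for_iter(i + 1)[p] = lookup_page(i + 1, p);
            }
        }
        for (std::int32_t page : pages.for_iter(i)) {  // consume the current slot
            consumed.push_back(page);
        }
    }
    return consumed;
}
```

With a single slot, the prefetch in this ordering would clobber the current iteration's pages before they are read, which is the race the double buffer avoids.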

**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.
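
The `dependent_false` guard is a standard C++ idiom. In this sketch the architecture is a hypothetical `Arch` enum template parameter rather than the compiler macros (`__gfx94__` / `__gfx950__`) the real code keys off, and `global_load_lds_path` is a hypothetical stand-in for the guarded intrinsic path.

```cpp
#include <cassert>

// Hypothetical stand-in for the target architecture.
enum class Arch { gfx90a, gfx942, gfx950 };

// Always-false value that depends on a template parameter, so the
// static_assert below fires only when that branch is actually instantiated.
template <auto>
inline constexpr bool dependent_false = false;

template <Arch A>
constexpr int global_load_lds_path() {
    if constexpr (A == Arch::gfx942 || A == Arch::gfx950) {
        return 1;  // placeholder for issuing the global_load_lds sequence
    } else {
        static_assert(dependent_false<A>,
                      "global_load_lds requires CDNA3+ (gfx94x / gfx950)");
        return 0;
    }
}
```

Instantiating `global_load_lds_path<Arch::gfx90a>()` fails to compile with the message above, while the gfx942 and gfx950 instantiations compile normally. Before C++23 relaxed the rule, a plain `static_assert(false, ...)` inside a template was ill-formed even in a discarded `if constexpr` branch; making the condition depend on the template parameter defers the check to instantiation.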

**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.
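
The convention can be sketched in portable C++. `HINT_ASSUME` and `BufferDesc` are hypothetical names; on Clang the macro maps to `__builtin_assume`, and on GCC to the `if (!(cond)) __builtin_unreachable();` idiom. Neither form calls `__assert_fail`, and a violated hint is undefined behavior, so it only suits invariants the caller already guarantees.

```cpp
#include <cassert>
#include <cstdint>

// Optimizer hint: states that a condition holds without an assert() call.
// Violating the condition is undefined behavior.
#if defined(__clang__)
#define HINT_ASSUME(cond) __builtin_assume(cond)
#elif defined(__GNUC__)
#define HINT_ASSUME(cond) do { if (!(cond)) __builtin_unreachable(); } while (0)
#else
#define HINT_ASSUME(cond) ((void)0)
#endif

// Hypothetical SRD-like holder with a hint-only precondition on its setter.
struct BufferDesc {
    std::uint32_t num_records = 0;

    void set_num_records(std::uint32_t n) {
        HINT_ASSUME(n > 0);  // no __assert_fail call, unlike assert(n > 0)
        num_records = n;
    }
};
```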

## Test Plan

Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:

- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.
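
The tolerance check is assumed here to follow the usual allclose rule, `|actual - expected| <= atol + rtol * |expected|` per element (the rule `torch.allclose` uses). The harness's actual `atol` / `rtol` values are not stated above, so this hypothetical `all_close` helper takes them as parameters.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise allclose: |a - b| <= atol + rtol * |b| for every element.
// NaN is not special-cased: any NaN makes the comparison fail.
bool all_close(const std::vector<float>& actual, const std::vector<float>& expected,
               float rtol, float atol) {
    if (actual.size() != expected.size()) return false;
    for (std::size_t i = 0; i < actual.size(); ++i) {
        const float diff = std::fabs(actual[i] - expected[i]);
        if (!(diff <= atol + rtol * std::fabs(expected[i]))) return false;
    }
    return true;
}
```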

## Test Result

| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2 GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |

**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads:
removing the setters reintroduces page faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-23 23:09:25 +00:00

Composable Kernel Tile

concept

ck_tile provides a programming model with templated abstractions that enables users to implement performance-critical kernels for machine-learning workloads. It introduces the following basic concepts to help users build their own operators:

  • tensor coordinate transformation: the core concept behind the layout/index transform abstraction, at both compile time and run time.
  • tile-based programming model: including the tile-level API and the concept of a distributed tensor.

ck_tile is independent of the old CK and is located under /include/ck_tile. You don't need to include anything from old CK; ck_tile has similar (indeed almost identical) implementations for users to build operators. There will be a transition period while everything is moved from old CK into ck_tile; stay tuned.

component

ck_tile is split into several components, including core, host, ops/gemm, ops/fmha, and more. For each component you only need to include a single header (e.g. #include "ck_tile/core.hpp", #include "ck_tile/ops/fmha.hpp") to use the functions and structures it provides (unlike old CK).

[core]
ck_tile/core contains all the basic data structures and functions needed to build a kernel. You can include only this header and build your own operators from all the basic building blocks introduced in CK.

core/container

  • array: stores runtime variables with a fixed length (tensor indices, register buffers, etc.)
  • tuple: like std::tuple, holds data of different types; one of the ways to implement multiple buffers.
  • sequence: a compile-time integer sequence used to build various internal structures or to describe tile sizes.
  • other convenience structures built on top of the above three.
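
As a rough mental model only, the three containers correspond to familiar standard-library types. The sketch uses std::array, std::tuple, and std::integer_sequence rather than ck_tile's own types, whose interfaces are not shown here; TileShape, tile_elements, Index2D, and Buffers are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

// Compile-time tile shape, in the spirit of ck_tile's sequence.
using TileShape = std::integer_sequence<int, 128, 64>;

// Product of all dimensions of a compile-time shape (a fold expression).
template <int... Dims>
constexpr int tile_elements(std::integer_sequence<int, Dims...>) {
    return (Dims * ... * 1);
}

// Runtime fixed-length storage, e.g. a 2-D tensor index.
using Index2D = std::array<int, 2>;

// Heterogeneous bundle, e.g. two buffers with different element types.
using Buffers = std::tuple<std::array<float, 4>, std::array<int, 4>>;
```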

core/numeric

  • GPU data types such as fp16_t, bf16_t, fp8_t, ..., and conversions between them
  • constexpr integers, similar to std::integral_constant, used as compile-time integers.
  • math functions and numeric utilities
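
As an example of the kind of conversion this component provides, here is a standalone float32 to bf16 conversion with round-to-nearest-even, plus the exact widening back to float. It is a generic sketch, not ck_tile's implementation (which may offer other rounding modes); float_to_bf16_rne and bf16_to_float are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// bf16 keeps the top 16 bits of an IEEE-754 float32.
std::uint16_t float_to_bf16_rne(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0) {
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);  // keep NaN, force quiet bit
    }
    // Round to nearest, ties to even, on the 16 bits being dropped.
    const std::uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return static_cast<std::uint16_t>(bits >> 16);
}

// Widening is exact: place the 16 bits in the upper half of a float32.
float bf16_to_float(std::uint16_t h) {
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```

For instance, 1 + 2^-8 lies exactly halfway between the bf16 neighbours 1.0 and 1 + 2^-7, and rounds to the even one, 1.0 (bit pattern 0x3F80).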

core/algorithm

  • coordinate transformation system: used to build tensor transforms and compile-time indexing. This is the core idea introduced in old CK: it describes how a tensor is built from a few basic transform primitives such as merge/unmerge/embed, and how indexing into an ND tensor is finally mapped to a 1D memory offset.
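
A runtime-only sketch of the idea: each primitive maps an upper (logical) index to a lower index, and chaining them yields the final 1D offset. merge and embed here are hypothetical free functions with simplified semantics (merge splits one flattened index back into row-major component indices; embed takes a dot product with per-dimension coefficients); ck_tile's transforms are compile-time objects with a richer interface.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
using Index = std::array<long, N>;

// merge: one upper index over a flattened space -> N lower indices (row-major).
template <std::size_t N>
constexpr Index<N> merge(long upper, const Index<N>& lower_lengths) {
    Index<N> lower{};
    for (std::size_t d = N; d-- > 0;) {
        lower[d] = upper % lower_lengths[d];
        upper /= lower_lengths[d];
    }
    return lower;
}

// embed: N upper indices -> one lower offset, via per-dimension coefficients.
template <std::size_t N>
constexpr long embed(const Index<N>& upper, const Index<N>& coefficients) {
    long offset = 0;
    for (std::size_t d = 0; d < N; ++d) {
        offset += upper[d] * coefficients[d];
    }
    return offset;
}
```

For example, viewing a 4 x 6 matrix as one merged dimension of 24, index 13 merges back to (row 2, col 1); if rows are stored with a padded stride of 8, embed maps that to offset 2 * 8 + 1 = 17.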

core/tensor

  • tensor descriptor: describes how an ND tensor is laid out, i.e. how its ND index maps to memory.
  • distributed tensor: describes the storage of a tensor and how it is distributed across a collection of threads that collaboratively work on it.
  • tile-level API: including load_tile, store_tile, shuffle_tile, slice_tile, etc.
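
To illustrate "distribution" in its simplest form: a tile_m x tile_n tile is split evenly across num_threads threads, each owning a contiguous run of elements in row-major order. The layout and the owned_coords helper are hypothetical; ck_tile's tile distributions support far more general layouts.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// (row, col) coordinates of the elements one thread owns when a
// tile_m x tile_n tile is split evenly across num_threads threads, each
// holding a contiguous row-major chunk in its registers.
std::vector<std::pair<int, int>> owned_coords(int thread_id, int tile_m, int tile_n,
                                              int num_threads) {
    assert(tile_m * tile_n % num_threads == 0);
    const int per_thread = tile_m * tile_n / num_threads;
    std::vector<std::pair<int, int>> coords;
    for (int k = 0; k < per_thread; ++k) {
        const int linear = thread_id * per_thread + k;
        coords.emplace_back(linear / tile_n, linear % tile_n);
    }
    return coords;
}
```

With a 4 x 8 tile and 8 threads, each thread owns 4 elements; thread 3 owns row 1, columns 4 through 7.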

[host]
ck_tile/host contains all the host-side utilities to launch a kernel, create device buffers, and some reference implementations. It can be used to create examples (like those under the ck_tile example folder) and simple executables that invoke a kernel; if you only need ck_tile to build your own device library, you don't need to include it. Even then, we recommend including only the specific headers you need from this folder (e.g. only ck_tile/host/kernel_launch.hpp) to avoid pulling in unwanted headers, unless you are writing a host executable.
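
Reference implementations of this kind are typically naive loops used to validate device results. Below is a sketch of such a CPU reference GEMM, assuming row-major A (M x K), B (K x N), and C (M x N) with float accumulation; reference_gemm is a hypothetical name, not the ck_tile/host API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive CPU reference: C[i][j] = sum_p A[i][p] * B[p][j], all row-major.
std::vector<float> reference_gemm(const std::vector<float>& a, const std::vector<float>& b,
                                  std::size_t m, std::size_t n, std::size_t k) {
    assert(a.size() == m * k && b.size() == k * n);
    std::vector<float> c(m * n, 0.0f);
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (std::size_t p = 0; p < k; ++p) {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    return c;
}
```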

[ops/gemm, ops/fmha, ops/reduce...]
Our implementations of the different device operators.

  • warp: warp-tile-level operators
  • block: block-tile-level operators
  • pipeline: a pipeline that implements a customized tile-level mainloop (or epilogue). By plugging different pipelines into the kernel template, you can get different kinds of pipeline optimizations.
  • kernel: the template interface users instantiate to get a particular kernel
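
Switching pipelines is ordinary policy-based design: the kernel template takes the pipeline as a type parameter and calls into it, so swapping the type swaps the mainloop without touching the kernel. In this sketch Kernel, SimplePipeline, and UnrolledPipeline are hypothetical, and the "mainloop" is just a dot product.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Two interchangeable mainloop policies with the same static interface.
struct SimplePipeline {
    static float run(const std::vector<float>& a, const std::vector<float>& b) {
        float acc = 0.0f;
        for (std::size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
        return acc;
    }
};

struct UnrolledPipeline {
    static float run(const std::vector<float>& a, const std::vector<float>& b) {
        float acc0 = 0.0f, acc1 = 0.0f;  // two accumulators, unrolled by 2
        std::size_t i = 0;
        for (; i + 1 < a.size(); i += 2) {
            acc0 += a[i] * b[i];
            acc1 += a[i + 1] * b[i + 1];
        }
        if (i < a.size()) acc0 += a[i] * b[i];  // odd-length tail
        return acc0 + acc1;
    }
};

// The kernel is written once; the pipeline is a compile-time choice.
template <typename Pipeline>
struct Kernel {
    static float invoke(const std::vector<float>& a, const std::vector<float>& b) {
        return Pipeline::run(a, b);
    }
};
```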

[ops/epilogue]
The epilogue part of our kernels. We may extend it to let users build their own customized epilogues.
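
One common way to make an epilogue user-customizable is to accept it as a callable applied to each accumulator value before the store. The sketch shows that pattern with a hypothetical store_with_epilogue function and a bias-plus-ReLU epilogue; it is not ck_tile's epilogue interface.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Applies a user-supplied epilogue to every accumulator value, then stores it.
template <typename Epilogue>
void store_with_epilogue(const std::vector<float>& acc, std::vector<float>& out,
                         Epilogue epilogue) {
    out.resize(acc.size());
    for (std::size_t i = 0; i < acc.size(); ++i) out[i] = epilogue(acc[i]);
}

// Example epilogue: add a bias, then clamp negatives to zero (ReLU).
struct BiasRelu {
    float bias;
    float operator()(float x) const { return std::max(x + bias, 0.0f); }
};
```

Because the epilogue is a template parameter, a lambda works just as well as BiasRelu, and the call can be inlined.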

[ref]
Reference implementations for CPU or GPU. Headers in this folder are meant to be included individually, on demand.

examples

Currently, all ck_tile-related examples are under the /example/ck_tile folder. Please check each example's subfolder.