mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653)

## Motivation

The CK batch prefill kernel previously failed (silent overflow and page faults) when the KV cache exceeded 2 GB, blocking long-context inference workloads (e.g., 128K+ token contexts with paged KV). Two distinct failure modes were addressed:

1. **>2 GB SRD overflow (`page_size < kN0`):** The SRD `buffer_load_dwordx4` path uses a 32-bit `voffset` register. For small page sizes the rebased SRD spans the full KV pool, so the byte offset wraps once it passes 2 GB and corrupts K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the hardware validates the **full SRD `num_records` range** against page-table permissions (CDNA3 only checks the per-instruction `voffset`). After the per-tile SRD rebase, an untrimmed `num_records` field extends past the live page and faults on freed or protected memory.

## Technical Details

**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad` template parameter:**

| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load (`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) + `load_tile_flat` (V) | 64-bit |

**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects the kernel instantiation at launch from `(page_block_size, num_total_pages * batch_stride_k * kElementBytes)`, so the small-page penalty is paid only when correctness requires it.

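The three-case table and the dispatch rule can be modelled in a few lines of Python. This is only an illustrative sketch, not the generated dispatch code: `select_load_mode`, `LoadMode`, and `wrap_to_int32` are hypothetical names, the argument names simply mirror the tuple quoted above, and the exact 2 GB boundary (`2**31 - 1` bytes) is an assumption drawn from the "fits in INT32 byte range" wording.

```python
from dataclasses import dataclass

# Assumption: "fits in INT32 byte range" means the KV pool is at most 2**31 - 1 bytes.
INT32_MAX = 2**31 - 1


def wrap_to_int32(byte_offset: int) -> int:
    """Reinterpret an integer as a signed 32-bit value (two's-complement wrap)."""
    return (byte_offset + 2**31) % 2**32 - 2**31


@dataclass(frozen=True)
class LoadMode:
    use_global_load: bool  # models the kUseGlobalLoad template parameter
    load_path: str         # load instruction family named in the table


def select_load_mode(page_block_size: int, num_total_pages: int,
                     batch_stride_k: int, element_bytes: int,
                     k_n0: int) -> LoadMode:
    """Hypothetical model of the launch-time choice between the two modes."""
    kv_pool_bytes = num_total_pages * batch_stride_k * element_bytes
    small_pages = page_block_size < k_n0
    if small_pages and kv_pool_bytes > INT32_MAX:
        # Case 3: small pages and a pool too large for a 32-bit byte offset.
        return LoadMode(True, "async_load_tile_raw_flat (K) + load_tile_flat (V)")
    # Cases 1 and 2: per-page rebase or a small pool keeps voffset in range.
    return LoadMode(False, "buffer_load_dwordx4")
```

In this model, an offset of exactly `2**31` bytes passed through `wrap_to_int32` comes back negative, which is the kind of wrap the first failure mode describes. Only the combination of small pages and a pool above the threshold selects the global-load path; every other input stays on the SRD path.
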
**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of `block_fmha_batch_prefill_pipeline_qr_ks_vs_async`, `set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after each rebase to constrain `num_records` to the live page. Required for CDNA4 page-table validation; harmless on CDNA3.

**Pipeline sync for the global-load path:**

- V uses synchronous `load_tile_flat`; K uses `async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load doesn't race against the next iteration's K rebase computation.

**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` / `__gfx950__` (CDNA3+). Other architectures hit a `dependent_false` static_assert with a descriptive message.

**Device-side assertion convention:** SRD setters use `__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s `assert()`. The latter introduces an `__assert_fail` call whose register pressure scatters the K-SRD scalar register window across conditional branches, corrupting `buffer_load_dwordx4` on gfx950.

## Test Plan

Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:

- **Functional matrix (96 cases):** `test_batch_prefill` covers `page_size ∈ {1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16, fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1, 4}` (parametrized to exercise per-sequence SRD rebase across batch boundaries).
- **>2 GB coverage:** `test_batch_prefill_large_kvcache` is extended to allocate a 5 GB+ KV cache pool and exercise both the `kUseGlobalLoad=true` (small-page) and `kUseGlobalLoad=false` (large-page rebase) paths. It includes both single-batch and multi-batch (`batch_size=4`) cases to exercise per-sequence SRD rebase across the >2 GB pool.
- **Numerical reference:** PyTorch SDPA, per-batch loop with `atol` / `rtol` from the existing batch prefill test harness.

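To get a feel for how many pages a 5 GB+ pool implies at each page size, the arithmetic can be sketched as below. The helper `pages_needed` and the shape parameters (`num_kv_heads`, `head_dim`, `element_bytes`) are hypothetical: the description above does not state the shapes the real test uses, and the sketch sizes a single K (or V) pool only.

```python
def pages_needed(target_bytes: int, page_size: int, num_kv_heads: int,
                 head_dim: int, element_bytes: int) -> int:
    """Smallest page count whose total size reaches target_bytes.

    Assumes (hypothetically) that one page stores
    page_size * num_kv_heads * head_dim elements.
    """
    page_bytes = page_size * num_kv_heads * head_dim * element_bytes
    # Integer ceiling division avoids float rounding for multi-GB sizes.
    return -(-target_bytes // page_bytes)


FIVE_GIB = 5 * 2**30

# Illustrative shapes only: 8 KV heads, head_dim 128, 2-byte (bf16) elements.
small_page_count = pages_needed(FIVE_GIB, page_size=16, num_kv_heads=8,
                                head_dim=128, element_bytes=2)
large_page_count = pages_needed(FIVE_GIB, page_size=1024, num_kv_heads=8,
                                head_dim=128, element_bytes=2)
```

Under these assumed shapes the small-page configuration needs 64 times as many pages as the large-page one for the same pool size, which is why the small-page case is the one whose rebased SRD can span more than 2 GB.
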
## Test Result

| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2 GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |

**Performance impact (gfx950, hot SRD path):**

- +2.67% kernel time on `seqlen=1024 / page_sz=1024 / bf16 / sglang / causal / soft_cap=30`, attributable in full to the two `set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas (5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads: removing the setters reintroduces page faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.