[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653)

## Motivation

The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).

Two distinct failure modes were addressed:

1. **>4GB SRD overflow (`page_size < kN0`):** The SRD
`buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small
page sizes the rebased SRD spans the full KV pool and the offset wraps
past 2 GB, corrupting K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.

## Technical Details

**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**

| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load
(`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) +
`load_tile_flat` (V) | 64-bit |

**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.

**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.

**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.

**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.

**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.


## Test Plan

Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:

- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.

## Test Result

| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2
GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |

**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads —
removing the setters re-introduces page-faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Jeff Huang
2026-04-24 07:08:41 +08:00
committed by GitHub
parent d7475e8125
commit b3d45b6fdb
10 changed files with 540 additions and 116 deletions

View File

@@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Flat async load from global memory to LDS using 64-bit global addressing.
// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds
// INT32_MAX (2GB) byte offset on the SRD voffset path.
//
// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!!
//
// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3:
// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`).
// M0 does NOT appear as an operand of these instructions or of the inline
// asm below — the compiler cannot see the dependency. Caller must:
//
// 1. Initialize M0 once before the load loop:
// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));`
// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to
// enforce this. Direct VALU writes to M0 are illegal.
//
// 2. Advance M0 between successive issues:
// `m0_inc_with_memory(size_per_issue);`
// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path
// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently
// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses
// M0[15:0] as a raw byte offset).
//
// 3. Never bundle `m0_inc_with_memory` and the next call to this
// function into a single inline asm. The compiler auto-inserts a
// hazard NOP between an SALU write to M0 and the consuming
// `global_load_lds_*`; bundling bypasses that and may read stale M0.
//
// The "memory" clobber on this asm is load-bearing: it prevents the
// compiler from reordering this load across other M0-touching helpers
// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered).
//
// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950):
// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000
// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both
// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but
// supported by the LLVM AMDGPU backend.
//
// Available on gfx940+ (CDNA3: MI300, MI355, MI350 series).
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void
async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
{
#if !defined(__gfx94__) && !defined(__gfx950__)
static_assert(always_false_v<integral_constant<unsigned, num_dwords>>,
"global_load_lds requires CDNA3+ (gfx940/gfx950). "
"Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
#endif
static_assert(num_dwords == 1 || num_dwords == 4,
"global_load_lds supports num_dwords == 1 or 4 only "
"(2 dwords does not exist on any supported arch; "
"3 dwords only on CDNA4 and unused in FMHA pipeline)");
// Inline asm: only the global address is an explicit operand. The LDS
// destination is implicit via M0 (see contract above). `"=r"(smem)` is a
// SSA scheduling anchor only — `smem` is NOT written by this asm; the
// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \
if constexpr(pre_nop) \
asm volatile("s_nop 4\n" instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/); \
else \
asm volatile(instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/);
if constexpr(num_dwords == 1)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
}
else if constexpr(num_dwords == 4)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
}
#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
}
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<int8_t, N>

View File

@@ -45,9 +45,29 @@ template <typename BottomTensorView_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
typename YsGatherDims = sequence<0>>
typename YsGatherDims = sequence<0>,
bool kUseGlobalLoad_ = false>
struct tile_scatter_gather
{
static constexpr bool kUseGlobalLoad = kUseGlobalLoad_;
#if !defined(__gfx94__) && !defined(__gfx950__)
// global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950).
// On other architectures, kUseGlobalLoad must be false.
static_assert(!kUseGlobalLoad_,
"kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). "
"This kernel should not be instantiated on this architecture.");
#endif
// Empty placeholder used by the SRD instantiation so physical_pages_ and
// page_stride_elements_ occupy zero bytes there (combined with
// [[no_unique_address]] on the member declarations). Access sites are all
// inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD
// mode, so no caller needs to change.
struct gl_field_empty_t
{
};
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
@@ -233,15 +253,22 @@ struct tile_scatter_gather
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx,
const ValidArray& valids)
const ValidArray& valids,
index_t page_stride_elements = 0)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
physical_pages_{},
page_stride_elements_{},
valids_{valids},
pre_computed_coords_{}
{
if constexpr(kUseGlobalLoad_)
{
page_stride_elements_ = page_stride_elements;
}
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
@@ -357,6 +384,34 @@ struct tile_scatter_gather
bottom_tensor_view_.buf_.p_data_ = data;
}
// Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for
// SRD num_records control. Use to set max range when SRD is rebased per-tile
// (page_size >= kN0 path): each rebased SRD only needs to cover one page; without
// this the SRD claims validity for memory beyond the allocated buffer, which can
// fault on gfx950 page-table validation.
//
// Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element
// count and is divided by PackedSize before being stored. For PackedSize=1
// (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4)
// skipping it would over-report num_records by 2x and silently mask OOB on SRD
// reads. batch_prefill currently does not exercise the packed-type path, but this
// setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must
// honor the same invariant the ctor enforces.
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size)
{
// Hint the optimizer that size is positive without inserting a runtime
// branch. Using <cassert> assert() here corrupted gfx950 batch_prefill
// output: the __assert_fail handler's SGPR pressure forced the K-SRD
// register window to be reused as scratch and scattered the SRD writes
// across two conditional branches, which gfx950's packed
// buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it
// via per-tile single-dword loads). __builtin_assume is hint-only —
// no branch, no scratch SGPRs, no codegen impact.
__builtin_assume(size > 0);
using BufType = remove_cvref_t<decltype(bottom_tensor_view_.buf_)>;
bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
@@ -458,7 +513,21 @@ struct tile_scatter_gather
// read from bottom tensor
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: 64-bit typed pointer arithmetic
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
vector_t v;
__builtin_memcpy(&v, addr, sizeof(vector_t));
return v;
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
@@ -680,7 +749,23 @@ struct tile_scatter_gather
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
if constexpr(kUseGlobalLoad_)
{
// Global load mode: global_load_lds with 64-bit address
constexpr index_t vector_size =
sizeof(vector_t) / sizeof(uint32_t); // dwords per vector
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
const auto physical_page = physical_pages_[idx_gather];
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
const long_index_t total_offset =
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
coord_offset + page_offset;
const auto* addr = base_ptr + total_offset;
// global_load_lds takes a byte address; addr (const DataType*)
// converts implicitly to const void*, no explicit cast needed.
async_global_load_lds_dwordxn<vector_size>(smem, addr, pre_nop_);
}
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
@@ -1046,6 +1131,13 @@ struct tile_scatter_gather
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages)
{
static_assert(kUseGlobalLoad_,
"global-load mode only; physical_pages_ is unused in SRD mode.");
physical_pages_ = pages;
}
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
@@ -1139,7 +1231,29 @@ struct tile_scatter_gather
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// Scatter/gather offsets for each element, set by update_page_idx().
// SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord).
// page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base)
// page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset)
// Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only.
// Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord
PageIdxArray page_idx_;
// Physical page indices for global load mode (kUseGlobalLoad=true only).
// Maps each gather element to its physical page in a paged memory pool.
// Updated via update_physical_pages() before each load call.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
physical_pages_;
// Page stride in elements for global load mode (kUseGlobalLoad=true only).
// physical_pages_[i] * page_stride_elements_ gives the page base offset in elements.
// Set at construction time via the make_tile_scatter_gather overload that
// takes bool_constant<kUseGlobalLoad>; immutable thereafter.
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
page_stride_elements_;
ValidArray valids_;
// this contains:
@@ -1178,7 +1292,8 @@ template <typename TensorView_,
typename StaticPageIndexArray_,
index_t HsGatherDim,
index_t NumCoord,
index_t... YsGatherDims>
index_t... YsGatherDims,
bool UseGlobalLoad = false>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
@@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim>,
number<NumCoord>,
sequence<YsGatherDims...>)
sequence<YsGatherDims...>,
bool_constant<UseGlobalLoad> = {},
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
@@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
std::nullptr_t,
HsGatherDim,
NumCoord,
sequence<YsGatherDims...>>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
sequence<YsGatherDims...>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}
// Legacy overload (compatible with original API)
// Legacy overload (compatible with original API, kUseGlobalLoad=false)
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1227,6 +1350,42 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
// Overload with kUseGlobalLoad (simple, used by K cache).
// page_stride_elements is forwarded to the constructor; required (non-zero)
// when UseGlobalLoad=true so that physical_pages_[i] * page_stride_elements_
// produces a valid address. Defaulting to 0 keeps SRD-mode call sites unchanged
// (page_stride_elements_ is unread in SRD mode).
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
bool UseGlobalLoad>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
bool_constant<UseGlobalLoad>,
index_t page_stride_elements = 0)
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
0,
1,
sequence<0>,
UseGlobalLoad>{tensor_view,
window_lengths,
origin,
tile_distribution,
page_idx,
nullptr,
page_stride_elements};
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,

View File

@@ -12,6 +12,20 @@
namespace ck_tile {
// `always_false_v<T...>` — a value-template that is always `false` but whose
// evaluation is deferred until template instantiation. The canonical use is
// inside the `else` arm of an `if constexpr` chain or under an arch-gated
// `#if` to fire a `static_assert` ONLY when the offending instantiation is
// actually requested, e.g.:
//
// if constexpr (...) { ... }
// else { static_assert(always_false_v<T>, "unsupported T"); }
//
// A bare `static_assert(false, ...)` would fire at template-definition
// parse time on conforming compilers, breaking the whole TU.
template <typename...>
inline constexpr bool always_false_v = false;
// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;