[CK] Fix OOB page table read in batch_prefill V prefetch (AICK-1171) (#6932)

## Summary

Fix a GPU memory access fault in `mha_batch_prefill` triggered when the
per-batch page table is tightly sized (no trailing slack).

**Affected configurations:**
- All FMHA batch prefill V2 kernels
(`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`)
- Triggered by paged KV layouts where `kv_page_indices.numel() ==
ceil(seqlen_k / page_size)` exactly
- Manifests as: `Memory access fault by GPU node-X (Agent handle:
0x...)` followed by `Aborted (core dumped)`
- Silent corruption (no fault, wrong output) when the OOB read happens
to land in zero-initialized memory

### Root cause

`load_physical_pages` performs **lookahead reads** on the page table to
prefetch K/V tiles for the next iteration. When the page table for a
batch has exactly `N` entries, the V-tile prefetch indexes `page_idx[N]`
(one past the last valid entry), reading either uninitialized memory or
the next batch's slot. On gfx942 with a tightly-sized page table, the
read crosses into an unmapped page and triggers an HSA page fault.

The bug was masked in earlier testing because most test harnesses pad
`kv_page_indices` with trailing zeros — OOB reads then return `page_id =
0`, a valid in-cache page, producing silent numerical drift instead of a
fault.

### Fix design

Thread `max_page_table_idx = (seqlen_k - 1) / page_size` from the kernel
layer down to `load_physical_pages`, and clamp every page-table read
with `ck_tile::min()`. Applied to **all four code paths** in the V
prefetch:

| Branch | What it does | Clamp applied |
|--------|--------------|---------------|
| `kIsKcache` | K prefetch loop | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` |
| V LINEAR (`page_size == 1`) | One token = one page | `min(global_token_idx, max_page_table_idx)` |
| V crosses pages (`kVTileCrossesPages`) | Per-thread page lookup | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` |
| V single page (lane0 broadcast) | `readfirstlane`-uniform lookup | `min(... >> kLog2PageSize, max_page_table_idx)` |

### Key design decisions

**Mandatory parameter, not optional with a sentinel default.** An
optional `max_page_table_idx = INT32_MAX` default would let the bug
silently come back at any new callsite that forgets to pass it. Making
it mandatory forces every caller to opt in explicitly and surfaces
missed callsites at compile time.

**`seqlen_k == 0` clamps to 0** instead of underflowing `(0 - 1) /
page_size` to `-1`. The empty-batch case is rare but well-defined: clamp
every read to slot 0.

**Single computation in the kernel layer.**
`FmhaBatchPrefillWithPagedKVCacheKernel` computes `max_page_table_idx`
once per batch and forwards it through every QScale branch (PERTENSOR /
KV_BLOCKSCALE / default). All three `operator()` overloads of the
pipeline (rich, default forwarder, KV_BLOCKSCALE forwarder) take and
forward the parameter.

### Files changed

| File | Change |
|------|--------|
| `include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp` | Compute `max_page_table_idx` per batch, forward to all 3 QScale branches |
| `include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp` | Add `max_page_table_idx` to `load_physical_pages` and 3 `operator()` overloads; clamp page-id reads in 4 code paths |

## Test plan

- [x] AICK-1171 reproducer verified on MI-308X (gfx942)
- [x] New pytest case `test_batch_prefill_aick1171_oob_page_table_read`
in aiter, parametrized over `total_blocks ∈ {160, 164, 168, 176, 208,
256}` (matches the `crash1_r8_*` bisect family)
- [x] Full FMHA batch prefill suite on gfx942 + gfx950

## Linked issue

AICK-1171.
Commit `537a9e7489` (parent `d7d7905980`) by Jeff Huang, 2026-05-05 14:28:19 +08:00, committed via GitHub. 2 changed files with 37 additions and 15 deletions.

`include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp`:

```diff
@@ -1250,6 +1250,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                                ? kargs.hdim_v
                                : kargs.stride_v;
+        // Last valid index into this batch's page table; load_physical_pages clamps
+        // page-table reads to [0, max_page_table_idx] to prevent OOB into the next
+        // batch's pages. Empty batch (seqlen_k == 0) clamps to 0.
+        const index_t max_page_table_idx =
+            kargs.seqlen_k > 0 ? (kargs.seqlen_k - 1) / kPageBlockSize : 0;
         auto o_acc_tile = [&] {
             if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
             {
@@ -1296,7 +1302,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                     kargs.batch_stride_k,
                     kargs.batch_stride_v,
                     dropout,
-                    sink_value);
+                    sink_value,
+                    max_page_table_idx);
             }
             else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
             {
@@ -1326,6 +1333,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                     kargs.batch_stride_v,
                     dropout,
                     sink_value,
+                    max_page_table_idx,
                     k_descale_ptr,
                     v_descale_ptr,
                     kargs.nblock_stride_kv_block_descale,
@@ -1352,7 +1360,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                     kargs.batch_stride_k,
                     kargs.batch_stride_v,
                     dropout,
-                    sink_value);
+                    sink_value,
+                    max_page_table_idx);
             }
         }();
```

`include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`:

```diff
@@ -35,7 +35,8 @@ template <typename IndexArrayType,
 CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
                                         const CoordVecType& coord_vec,
                                         index_t global_seq_offset,
-                                        IndexArrayType& physical_pages)
+                                        IndexArrayType& physical_pages,
+                                        index_t max_page_table_idx)
 {
     static constexpr index_t kLog2PageSize = [] {
         index_t shift = 0;
@@ -56,8 +57,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
         static_for<0, kLoopCount, 1>{}([&](auto k0) {
             const index_t global_token_idx =
                 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            const index_t page_id = global_token_idx >> kLog2PageSize;
-            physical_pages[k0] = page_idx[page_id];
+            const index_t page_id =
+                ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx);
+            physical_pages[k0] = page_idx[page_id];
         });
     }
     else
@@ -75,7 +77,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
         static_for<0, kLoopCount, 1>{}([&](auto k0) {
             const index_t global_token_idx =
                 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            physical_pages[k0] = page_idx[global_token_idx];
+            physical_pages[k0] = page_idx[ck_tile::min(global_token_idx, max_page_table_idx)];
         });
     }
     else if constexpr(kVTileCrossesPages)
@@ -85,8 +87,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
         static_for<0, kLoopCount, 1>{}([&](auto k0) {
             const index_t global_token_idx =
                 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            const index_t page_id = global_token_idx >> kLog2PageSize;
-            physical_pages[k0] = page_idx[page_id];
+            const index_t page_id =
+                ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx);
+            physical_pages[k0] = page_idx[page_id];
         });
     }
     else
@@ -94,7 +97,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
         // V tile fully contained in one page: lane0 lookup, broadcast to all
         const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
         const index_t lane0_page_id =
-            (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
+            ck_tile::min((global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize,
+                         max_page_table_idx);
         const index_t shared_physical_page = page_idx[lane0_page_id];
         static_for<0, kLoopCount, 1>{}(
@@ -427,6 +431,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         const index_t page_stride_v,
         DropoutType& dropout,
         const float sink_v,
+        const index_t max_page_table_idx,
         // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE)
         const float* k_descale_ptr = nullptr,
         const float* v_descale_ptr = nullptr,
@@ -611,7 +616,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             kN0 / NRepeat,
             kKVMemoryLayout,
             true,
-            kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
+            kN0>(
+            page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx);
         kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
                                   decltype(k_coord),
@@ -839,7 +845,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             1,
             kKVMemoryLayout,
             false,
-            kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2);
+            kN0>(
+            page_idx, v_coord, current_seq_k, v_physical_pages_k2, max_page_table_idx);
         // Copy to merged array
         static_for<0, V_KIterInner, 1>{}([&](auto k1) {
@@ -859,7 +866,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             1,
             kKVMemoryLayout,
             false,
-            kN0>(page_idx, v_coord, current_seq_k, v_physical_pages);
+            kN0>(
+            page_idx, v_coord, current_seq_k, v_physical_pages, max_page_table_idx);
         }
     };
@@ -1516,7 +1524,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             kN0 / NRepeat,
             kKVMemoryLayout,
             true,
-            kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
+            kN0>(
+            page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx);
         kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
                                   decltype(k_coord),
@@ -1672,7 +1681,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         const index_t page_stride_k,
         const index_t page_stride_v,
         DropoutType& dropout,
-        float sink_v) const
+        float sink_v,
+        const index_t max_page_table_idx) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},
@@ -1701,7 +1711,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                           page_stride_k,
                           page_stride_v,
                           dropout,
-                          sink_v);
+                          sink_v,
+                          max_page_table_idx);
     }
     // Overload for KV_BLOCKSCALE: K/V descale is per-page
@@ -1736,6 +1747,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         const index_t page_stride_v,
         DropoutType& dropout,
         float sink_v,
+        const index_t max_page_table_idx,
         const float* k_descale_ptr,
         const float* v_descale_ptr,
         index_t nblock_stride_kv_block_descale,
@@ -1769,6 +1781,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                           page_stride_v,
                           dropout,
                           sink_v,
+                          max_page_table_idx,
                           k_descale_ptr,
                           v_descale_ptr,
                           nblock_stride_kv_block_descale,
```