mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
[CK] Fix OOB page table read in batch_prefill V prefetch (AICK-1171) (#6932)
## Summary
Fix a GPU memory access fault in `mha_batch_prefill` triggered when the
per-batch page table is tightly sized (no trailing slack).
**Affected configurations:**
- All FMHA batch prefill V2 kernels
(`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`)
- Triggered by paged KV layouts where `kv_page_indices.numel() ==
ceil(seqlen_k / page_size)` exactly
- Manifests as: `Memory access fault by GPU node-X (Agent handle:
0x...)` followed by `Aborted (core dumped)`
- Silent corruption (no fault, wrong output) when the OOB read happens
to land in zero-initialized memory
### Root cause
`load_physical_pages` performs **lookahead reads** on the page table to
prefetch K/V tiles for the next iteration. When the page table for a
batch has exactly `N` entries, the V-tile prefetch indexes `page_idx[N]`
(one past the last valid entry), reading either uninitialized memory or
the next batch's slot. On gfx942 with a tightly-sized page table, the
read crosses into an unmapped page and triggers an HSA page fault.
The bug was masked in earlier testing because most test harnesses pad
`kv_page_indices` with trailing zeros — OOB reads then return `page_id =
0`, a valid in-cache page, producing silent numerical drift instead of a
fault.
### Fix design
Thread `max_page_table_idx = (seqlen_k - 1) / page_size` from the kernel
layer down to `load_physical_pages`, and clamp every page-table read
with `ck_tile::min()`. Applied to **all four code paths** in the V
prefetch:
| Branch | What it does | Clamp applied |
|--------|--------------|---------------|
| `kIsKcache` | K prefetch loop | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` |
| V LINEAR (`page_size == 1`) | One token = one page | `min(global_token_idx, max_page_table_idx)` |
| V crosses pages (`kVTileCrossesPages`) | Per-thread page lookup | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` |
| V single page (lane0 broadcast) | `readfirstlane`-uniform lookup | `min(... >> kLog2PageSize, max_page_table_idx)` |
### Key design decisions
**Mandatory parameter, not optional with a sentinel default.** An
optional `max_page_table_idx = INT32_MAX` default would let the bug
silently come back at any new callsite that forgets to pass it. Making
it mandatory forces every caller to opt in explicitly and surfaces
missed callsites at compile time.
**`seqlen_k == 0` clamps to 0** instead of computing `(0 - 1) /
page_size`, whose signed integer division would yield a negative page-table
index. The empty-batch case is rare but well-defined: every read clamps to
slot 0.
**Single computation in the kernel layer.**
`FmhaBatchPrefillWithPagedKVCacheKernel` computes `max_page_table_idx`
once per batch and forwards it through every QScale branch (PERTENSOR /
KV_BLOCKSCALE / default). All three `operator()` overloads of the
pipeline (rich, default forwarder, KV_BLOCKSCALE forwarder) take and
forward the parameter.
### Files changed
| File | Change |
|------|--------|
| `include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp` | Compute `max_page_table_idx` per batch, forward to all 3 QScale branches |
| `include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp` | Add `max_page_table_idx` to `load_physical_pages` and 3 `operator()` overloads; clamp page-id reads in 4 code paths |
## Test plan
- [x] AICK-1171 reproducer verified on MI-308X (gfx942)
- [x] New pytest case `test_batch_prefill_aick1171_oob_page_table_read`
in aiter, parametrized over `total_blocks ∈ {160, 164, 168, 176, 208,
256}` (matches the `crash1_r8_*` bisect family)
- [x] Full FMHA batch prefill suite on gfx942 + gfx950
## Linked issue
AICK-1171.
This commit is contained in:
@@ -1250,6 +1250,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
? kargs.hdim_v
|
||||
: kargs.stride_v;
|
||||
|
||||
// Last valid index into this batch's page table; load_physical_pages clamps
|
||||
// page-table reads to [0, max_page_table_idx] to prevent OOB into the next
|
||||
// batch's pages. Empty batch (seqlen_k == 0) clamps to 0.
|
||||
const index_t max_page_table_idx =
|
||||
kargs.seqlen_k > 0 ? (kargs.seqlen_k - 1) / kPageBlockSize : 0;
|
||||
|
||||
auto o_acc_tile = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1296,7 +1302,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value);
|
||||
sink_value,
|
||||
max_page_table_idx);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
@@ -1326,6 +1333,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value,
|
||||
max_page_table_idx,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.nblock_stride_kv_block_descale,
|
||||
@@ -1352,7 +1360,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value);
|
||||
sink_value,
|
||||
max_page_table_idx);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -35,7 +35,8 @@ template <typename IndexArrayType,
|
||||
CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
const CoordVecType& coord_vec,
|
||||
index_t global_seq_offset,
|
||||
IndexArrayType& physical_pages)
|
||||
IndexArrayType& physical_pages,
|
||||
index_t max_page_table_idx)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
index_t shift = 0;
|
||||
@@ -56,8 +57,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
const index_t page_id =
|
||||
ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx);
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -75,7 +77,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
physical_pages[k0] = page_idx[global_token_idx];
|
||||
physical_pages[k0] = page_idx[ck_tile::min(global_token_idx, max_page_table_idx)];
|
||||
});
|
||||
}
|
||||
else if constexpr(kVTileCrossesPages)
|
||||
@@ -85,8 +87,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
const index_t page_id =
|
||||
ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx);
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -94,7 +97,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
// V tile fully contained in one page: lane0 lookup, broadcast to all
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
ck_tile::min((global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize,
|
||||
max_page_table_idx);
|
||||
const index_t shared_physical_page = page_idx[lane0_page_id];
|
||||
|
||||
static_for<0, kLoopCount, 1>{}(
|
||||
@@ -427,6 +431,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
const float sink_v,
|
||||
const index_t max_page_table_idx,
|
||||
// KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE)
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
@@ -611,7 +616,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
|
||||
kN0>(
|
||||
page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx);
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
@@ -839,7 +845,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2);
|
||||
kN0>(
|
||||
page_idx, v_coord, current_seq_k, v_physical_pages_k2, max_page_table_idx);
|
||||
|
||||
// Copy to merged array
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
@@ -859,7 +866,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages);
|
||||
kN0>(
|
||||
page_idx, v_coord, current_seq_k, v_physical_pages, max_page_table_idx);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1516,7 +1524,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
|
||||
kN0>(
|
||||
page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx);
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
@@ -1672,7 +1681,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
float sink_v) const
|
||||
float sink_v,
|
||||
const index_t max_page_table_idx) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1701,7 +1711,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout,
|
||||
sink_v);
|
||||
sink_v,
|
||||
max_page_table_idx);
|
||||
}
|
||||
|
||||
// Overload for KV_BLOCKSCALE: K/V descale is per-page
|
||||
@@ -1736,6 +1747,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout,
|
||||
float sink_v,
|
||||
const index_t max_page_table_idx,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
index_t nblock_stride_kv_block_descale,
|
||||
@@ -1769,6 +1781,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
page_stride_v,
|
||||
dropout,
|
||||
sink_v,
|
||||
max_page_table_idx,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
nblock_stride_kv_block_descale,
|
||||
|
||||
Reference in New Issue
Block a user