From 537a9e7489e4274c433e084c1a22718449490621 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 5 May 2026 14:28:19 +0800 Subject: [PATCH] [CK] Fix OOB page table read in batch_prefill V prefetch (AICK-1171) (#6932) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fix a GPU memory access fault in `mha_batch_prefill` triggered when the per-batch page table is tightly sized (no trailing slack). **Affected configurations:** - All FMHA batch prefill V2 kernels (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`) - Triggered by paged KV layouts where `kv_page_indices.numel() == ceil(seqlen_k / page_size)` exactly - Manifests as: `Memory access fault by GPU node-X (Agent handle: 0x...)` followed by `Aborted (core dumped)` - Silent corruption (no fault, wrong output) when the OOB read happens to land in zero-initialized memory ### Root cause `load_physical_pages` performs **lookahead reads** on the page table to prefetch K/V tiles for the next iteration. When the page table for a batch has exactly `N` entries, the V-tile prefetch indexes `page_idx[N]` (one past the last valid entry), reading either uninitialized memory or the next batch's slot. On gfx942 with a tightly-sized page table, the read crosses into an unmapped page and triggers an HSA page fault. The bug was masked in earlier testing because most test harnesses pad `kv_page_indices` with trailing zeros — the out-of-bounds slot then reads as physical page 0, a valid in-cache page, producing silent numerical drift instead of a fault. ### Fix design Thread `max_page_table_idx = (seqlen_k - 1) / page_size` from the kernel layer down to `load_physical_pages`, and clamp every page-table read with `ck_tile::min()`. 
Applied to **all four code paths** in `load_physical_pages` (the K prefetch loop plus the three V-prefetch branches): | Branch | What it does | Clamp applied | |--------|-------------|---------------| | `kIsKcache` | K prefetch loop | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` | | V LINEAR (`page_size == 1`) | One token = one page | `min(global_token_idx, max_page_table_idx)` | | V crosses pages (`kVTileCrossesPages`) | Per-thread page lookup | `min(global_token_idx >> kLog2PageSize, max_page_table_idx)` | | V single page (lane0 broadcast) | `readfirstlane`-uniform lookup | `min(... >> kLog2PageSize, max_page_table_idx)` | ### Key design decisions **Mandatory parameter, not optional with a sentinel default.** An optional `max_page_table_idx = INT32_MAX` default would let the bug silently come back at any new callsite that forgets to pass it. Making it mandatory forces every caller to opt in explicitly and surfaces missed callsites at compile time. **`seqlen_k == 0` clamps to 0** instead of evaluating `(0 - 1) / page_size`, which would produce a negative index (or a huge wrapped index if the arithmetic is unsigned) — either way an invalid page-table slot. The empty-batch case is rare but well-defined: clamp every read to slot 0. **Single computation in the kernel layer.** `FmhaBatchPrefillWithPagedKVCacheKernel` computes `max_page_table_idx` once per batch and forwards it through every QScale branch (PERTENSOR / KV_BLOCKSCALE / default). All three `operator()` overloads of the pipeline (rich, default forwarder, KV_BLOCKSCALE forwarder) take and forward the parameter. 
### Files changed | File | Change | |------|--------| | `include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp` | Compute `max_page_table_idx` per batch, forward to all 3 QScale branches | | `include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp` | Add `max_page_table_idx` to `load_physical_pages` and 3 `operator()` overloads; clamp page-id reads in 4 code paths | ## Test plan - [x] AICK-1171 reproducer verified on MI-308X (gfx942) - [x] New pytest case `test_batch_prefill_aick1171_oob_page_table_read` in aiter, parametrized over `total_blocks ∈ {160, 164, 168, 176, 208, 256}` (matches the `crash1_r8_*` bisect family) - [x] Full FMHA batch prefill suite on gfx942 + gfx950 ## Linked issue AICK-1171. --- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 13 ++++++- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 39 ++++++++++++------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index a523acd291..cab9ee5944 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1250,6 +1250,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ? kargs.hdim_v : kargs.stride_v; + // Last valid index into this batch's page table; load_physical_pages clamps + // page-table reads to [0, max_page_table_idx] to prevent OOB into the next + // batch's pages. Empty batch (seqlen_k == 0) clamps to 0. + const index_t max_page_table_idx = + kargs.seqlen_k > 0 ? 
(kargs.seqlen_k - 1) / kPageBlockSize : 0; + auto o_acc_tile = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1296,7 +1302,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_k, kargs.batch_stride_v, dropout, - sink_value); + sink_value, + max_page_table_idx); } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { @@ -1326,6 +1333,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_v, dropout, sink_value, + max_page_table_idx, k_descale_ptr, v_descale_ptr, kargs.nblock_stride_kv_block_descale, @@ -1352,7 +1360,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_k, kargs.batch_stride_v, dropout, - sink_value); + sink_value, + max_page_table_idx); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 8aa6d17dc3..adc24943e6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -35,7 +35,8 @@ template {}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0] = page_idx[page_id]; + const index_t page_id = + ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx); + physical_pages[k0] = page_idx[page_id]; }); } else @@ -75,7 +77,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - physical_pages[k0] = page_idx[global_token_idx]; + physical_pages[k0] = page_idx[ck_tile::min(global_token_idx, max_page_table_idx)]; }); } else if 
constexpr(kVTileCrossesPages) @@ -85,8 +87,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0] = page_idx[page_id]; + const index_t page_id = + ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx); + physical_pages[k0] = page_idx[page_id]; }); } else @@ -94,7 +97,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, // V tile fully contained in one page: lane0 lookup, broadcast to all const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = - (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + ck_tile::min((global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize, + max_page_table_idx); const index_t shared_physical_page = page_idx[lane0_page_id]; static_for<0, kLoopCount, 1>{}( @@ -427,6 +431,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_v, DropoutType& dropout, const float sink_v, + const index_t max_page_table_idx, // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE) const float* k_descale_ptr = nullptr, const float* v_descale_ptr = nullptr, @@ -611,7 +616,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx); kv_offset_array_transform, decltype(k_coord), @@ -839,7 +845,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2); + kN0>( + page_idx, v_coord, current_seq_k, v_physical_pages_k2, max_page_table_idx); // Copy to merged array static_for<0, V_KIterInner, 1>{}([&](auto k1) { @@ -859,7 
+866,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages); + kN0>( + page_idx, v_coord, current_seq_k, v_physical_pages, max_page_table_idx); } }; @@ -1516,7 +1524,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx); kv_offset_array_transform, decltype(k_coord), @@ -1672,7 +1681,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_k, const index_t page_stride_v, DropoutType& dropout, - float sink_v) const + float sink_v, + const index_t max_page_table_idx) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -1701,7 +1711,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_stride_k, page_stride_v, dropout, - sink_v); + sink_v, + max_page_table_idx); } // Overload for KV_BLOCKSCALE: K/V descale is per-page @@ -1736,6 +1747,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_v, DropoutType& dropout, float sink_v, + const index_t max_page_table_idx, const float* k_descale_ptr, const float* v_descale_ptr, index_t nblock_stride_kv_block_descale, @@ -1769,6 +1781,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_stride_v, dropout, sink_v, + max_page_table_idx, k_descale_ptr, v_descale_ptr, nblock_stride_kv_block_descale,