diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index a523acd291..cab9ee5944 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1250,6 +1250,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ? kargs.hdim_v : kargs.stride_v; + // Last valid index into this batch's page table; load_physical_pages clamps + // page-table reads to [0, max_page_table_idx] to prevent OOB into the next + // batch's pages. Empty batch (seqlen_k == 0) clamps to 0. + const index_t max_page_table_idx = + kargs.seqlen_k > 0 ? (kargs.seqlen_k - 1) / kPageBlockSize : 0; + auto o_acc_tile = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1296,7 +1302,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_k, kargs.batch_stride_v, dropout, - sink_value); + sink_value, + max_page_table_idx); } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { @@ -1326,6 +1333,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_v, dropout, sink_value, + max_page_table_idx, k_descale_ptr, v_descale_ptr, kargs.nblock_stride_kv_block_descale, @@ -1352,7 +1360,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.batch_stride_k, kargs.batch_stride_v, dropout, - sink_value); + sink_value, + max_page_table_idx); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 8aa6d17dc3..adc24943e6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -35,7 +35,8 @@ template {}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0] = page_idx[page_id]; + const index_t page_id = + ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx); + physical_pages[k0] = page_idx[page_id]; }); } else @@ -75,7 +77,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - physical_pages[k0] = page_idx[global_token_idx]; + physical_pages[k0] = page_idx[ck_tile::min(global_token_idx, max_page_table_idx)]; }); } else if constexpr(kVTileCrossesPages) @@ -85,8 +87,9 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0] = page_idx[page_id]; + const index_t page_id = + ck_tile::min(global_token_idx >> kLog2PageSize, max_page_table_idx); + physical_pages[k0] = page_idx[page_id]; }); } else @@ -94,7 +97,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, // V tile fully contained in one page: lane0 lookup, broadcast to all const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = - (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + ck_tile::min((global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize, + max_page_table_idx); const index_t shared_physical_page = page_idx[lane0_page_id]; static_for<0, kLoopCount, 1>{}( @@ -427,6 +431,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_v, DropoutType& dropout, const float sink_v, + const index_t max_page_table_idx, // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE) const float* k_descale_ptr = nullptr, const float* v_descale_ptr = nullptr, @@ -611,7 +616,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx); kv_offset_array_transform, decltype(k_coord), @@ -839,7 +845,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2); + kN0>( + page_idx, v_coord, current_seq_k, v_physical_pages_k2, max_page_table_idx); // Copy to merged array static_for<0, V_KIterInner, 1>{}([&](auto k1) { @@ -859,7 +866,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages); + kN0>( + page_idx, v_coord, current_seq_k, v_physical_pages, max_page_table_idx); } }; @@ -1516,7 +1524,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, k_physical_pages, max_page_table_idx); kv_offset_array_transform, decltype(k_coord), @@ -1672,7 +1681,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_k, const index_t page_stride_v, DropoutType& dropout, - float sink_v) const + float sink_v, + const index_t max_page_table_idx) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -1701,7 +1711,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_stride_k, page_stride_v, dropout, - sink_v); + sink_v, + max_page_table_idx); } // Overload for KV_BLOCKSCALE: K/V descale is per-page @@ -1736,6 +1747,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_v, DropoutType& dropout, float sink_v, + const index_t max_page_table_idx, const float* k_descale_ptr, const float* v_descale_ptr, index_t nblock_stride_kv_block_descale, @@ -1769,6 +1781,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_stride_v, dropout, sink_v, + max_page_table_idx, k_descale_ptr, v_descale_ptr, nblock_stride_kv_block_descale,