From 5d5e8a1f8a2e7cd5893059c17c9931babc67457f Mon Sep 17 00:00:00 2001
From: Jeff Huang
Date: Thu, 5 Mar 2026 09:08:01 +0800
Subject: [PATCH] [CK] Fix 32-bit overflow in batch prefill kernel for >4GB KV cache (#4999)

Use SRD rebasing for page_block_size >= kN0: move the SRD base pointer to the
page start via 48-bit pointer arithmetic, and encode only the within-page
offset in voffset. The original code path is preserved for ps1/ps16 via
constexpr-if.

## Motivation

With a paged KV cache larger than 4 GB, the global byte offset
`physical_page * stride_page_block` can exceed INT32_MAX, overflowing the
32-bit voffset consumed by buffer load instructions, so K/V loads read from
the wrong addresses.

## Technical Details

- `page_block_size >= kN0`: a K/V tile never spans more than one page, so all
  threads in a wave read from the same page. Each page's base address is folded
  into the SRD base pointer via 48-bit pointer arithmetic (`rebase_k_window` /
  `rebase_v_window`, with `readfirstlane` keeping the page index wave-uniform
  so the SRD stays in SGPRs), and voffset carries only the small within-page
  offset.
- `page_block_size < kN0` (ps1/ps16): the original global-offset path is kept
  unchanged; new host-side asserts verify that the maximum K/V offset still
  fits in int32.

## Test Plan

## Test Result

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---
 .../fmha/kernel/fmha_batch_prefill_kernel.hpp |  28 +++++
 ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 118 ++++++++++++++----
 2 files changed, 120 insertions(+), 26 deletions(-)

diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
index c6628f66be..53934ebcd3 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -484,6 +484,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             kargs.init_logits_soft_cap(logits_soft_cap);
         }
 
+        // Check that the maximum offset won't overflow.
+        if constexpr(kPageBlockSize < FmhaPipeline::kN0)
+        {
+            if(num_total_pages > 1)
+            {
+                assert(static_cast<long_index_t>(num_total_pages - 1) * batch_stride_k <=
+                           static_cast<long_index_t>(std::numeric_limits<int32_t>::max()) &&
+                       "KV cache K offset overflow: exceeds int32 max");
+                assert(static_cast<long_index_t>(num_total_pages - 1) * batch_stride_v <=
+                           static_cast<long_index_t>(std::numeric_limits<int32_t>::max()) &&
+                       "KV cache V offset overflow: exceeds int32 max");
+            }
+        }
+
         return kargs;
     }
 
@@ -637,6 +651,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             kargs.init_logits_soft_cap(logits_soft_cap);
         }
 
+        // Check that the maximum offset won't overflow.
+        if constexpr(kPageBlockSize < FmhaPipeline::kN0)
+        {
+            if(num_total_pages > 1)
+            {
+                assert(static_cast<long_index_t>(num_total_pages - 1) * batch_stride_k <=
+                           static_cast<long_index_t>(std::numeric_limits<int32_t>::max()) &&
+                       "KV cache K offset overflow: exceeds int32 max");
+                assert(static_cast<long_index_t>(num_total_pages - 1) * batch_stride_v <=
+                           static_cast<long_index_t>(std::numeric_limits<int32_t>::max()) &&
+                       "KV cache V offset overflow: exceeds int32 max");
+            }
+        }
+
         return kargs;
     }
 
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index 1bc84836d3..a8b94b6e41 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -160,18 +160,27 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
     {
         // K cache: per-token lookup
         // Each token may be on a different page, so we use physical_pages[k0] for each.
-        // Offset = physical_page * stride_page_block + token_idx_in_page * stride_token
         static_for<0, kLoopCount, 1>{}([&](auto k0) {
             const index_t global_token_idx =
                 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
             const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-            const index_t physical_page = physical_pages[k0];
-            kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
-                                static_cast<long_index_t>(token_idx_in_page) * stride_token;
+            if constexpr(kPageBlockSize >= kN0)
+            {
+                // SRD rebasing mode: within-page offset only.
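+                // (token_idx_in_page < kPageBlockSize, so this offset always fits in 32 bits.)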
+                // The full page base is handled by rebasing the SRD pointer.
+                kv_offset_vec[k0] = token_idx_in_page * stride_token;
+            }
+            else
+            {
+                // Full global offset (original code path for ps1, ps16, etc.)
+                const index_t physical_page = physical_pages[k0];
+                kv_offset_vec[k0] =
+                    physical_page * stride_page_block + token_idx_in_page * stride_token;
+            }
         });
     }
-    else // !kVTileCrossesPages
+    else // V cache
     {
         // V cache: use physical_pages[k0] for each token
         // physical_pages was already populated correctly by load_physical_pages(), handling:
@@ -182,31 +191,43 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
             const index_t global_token_idx =
                 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
             const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-            const index_t physical_page = physical_pages[k0];
-            const long_index_t page_base_offset =
-                static_cast<long_index_t>(physical_page) * stride_page_block;
-
-            if constexpr(kKVMemoryLayout ==
-                         BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+            if constexpr(kPageBlockSize >= kN0)
             {
-                // Vectorized layout offset calculation:
-                // Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
-                // Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) +
-                //          (token % kVectorSize)
-                const long_index_t token_offset =
-                    static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
-                                              (stride_token * kVectorSize)) +
-                    (token_idx_in_page % kVectorSize);
-
-                kv_offset_vec[k0] = page_base_offset + token_offset;
+                // SRD rebasing mode: within-page offset only.
+                // The full page base is handled by rebasing the SRD pointer.
+                if constexpr(kKVMemoryLayout ==
+                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+                {
+                    const index_t token_offset =
+                        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
+                        (token_idx_in_page % kVectorSize);
+                    kv_offset_vec[k0] = token_offset;
+                }
+                else
+                {
+                    kv_offset_vec[k0] = token_idx_in_page * stride_token;
+                }
             }
-            else // LINEAR_LAYOUT
+            else
             {
-                // Linear layout: [page, token_in_page, head_dim]
-                // Offset = page_base + token_idx_in_page * stride_token
-                kv_offset_vec[k0] =
-                    page_base_offset + static_cast<long_index_t>(token_idx_in_page) * stride_token;
+                // Full global offset (original code path for ps1, ps16, etc.)
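+                // Computed in 64 bits; the asserts added at kargs construction catch
+                // configurations where this global offset would overflow int32.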
+                const index_t physical_page = physical_pages[k0];
+                const long_index_t page_base_offset =
+                    static_cast<long_index_t>(physical_page) * stride_page_block;
+
+                if constexpr(kKVMemoryLayout ==
+                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+                {
+                    const index_t token_offset =
+                        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
+                        (token_idx_in_page % kVectorSize);
+                    kv_offset_vec[k0] = page_base_offset + token_offset;
+                }
+                else
+                {
+                    kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
+                }
             }
         });
     }
@@ -561,6 +582,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         auto k_coord              = k_dist.calculate_index();
         using KDstrEncode         = typename decltype(k_dist)::DstrEncode;
         constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
+        // kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
+        // kPageBlockSize < kN0: global offset, must fit int32
         statically_indexed_array<index_t, NRepeat> k_offsets;
 
         index_t current_seq_k = seqlen_k_start;
@@ -597,6 +620,40 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                                k_dist,
                                                k_offsets); // K DRAM tile window for load
         k_dram_window.init_raw();
+
+        // SRD rebasing: move the buffer descriptor base pointer to each page's start address
+        // using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
+        // Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
+        auto rebase_k_window = [&](auto& window, index_t physical_page) {
+            if constexpr(kPageBlockSize >= kN0)
+            {
+                // readfirstlane: make physical_page provably wave-uniform so the
+                // resulting SRD lands in SGPRs (required by buffer load instructions).
+                physical_page = __builtin_amdgcn_readfirstlane(physical_page);
+                const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_;
+                const auto* page_ptr =
+                    base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
+                window.set_bottom_tensor_view_data_ptr(page_ptr);
+                window.init_raw();
+            }
+        };
+
+        auto rebase_v_window = [&](auto& window, index_t physical_page) {
+            if constexpr(kPageBlockSize >= kN0)
+            {
+                physical_page = __builtin_amdgcn_readfirstlane(physical_page);
+                const auto* base_ptr =
+                    v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
+                const auto* page_ptr =
+                    base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
+                window.set_bottom_tensor_view_data_ptr(page_ptr);
+                window.init_raw();
+            }
+        };
+
+        // Initial K SRD rebase
+        rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
+
         constexpr auto k_oob_ck = bool_constant<true>{};
         constexpr auto k_pre_np = [&]() {
             if constexpr(kPadSeqLenK &&
@@ -727,6 +784,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
                       "V page-index Y dim must be valid");
 
+        // kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_v_window)
+        // kPageBlockSize < kN0: global offset, must fit int32
         statically_indexed_array<index_t, V_KIterInner> v_offsets;
         // V physical pages array for use with kv_offset_array_transform
         // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
@@ -843,6 +902,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             number<1>{}, // NumCoord
             VPageIndexYDims);
 
+        // Initial V SRD rebase
+        rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
+
         // prefetch K tile
         async_load_tile_raw(
             k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -946,6 +1008,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             // V physical pages already prefetched before GEMM0
             update_v_offsets(number<kK1>{});
             v_dram_window.update_page_idx(v_offsets);
+            rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
 
             // KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result)
             if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
@@ -1124,6 +1187,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                     v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
                 update_v_offsets(number<2 * kK1>{});
                 v_dram_window.update_page_idx(v_offsets);
+                rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
             }
 
             __builtin_amdgcn_sched_barrier(0);
@@ -1283,6 +1347,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                     // Update V offsets using previously prefetched physical pages
                     update_v_offsets(number<(2 + i_k1.value) * kK1>{});
                     v_dram_window.update_page_idx(v_offsets);
+                    rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
                 }
 
                 // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
@@ -1361,6 +1426,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                     kVectorSize>(
                 k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
             k_dram_window.update_page_idx(k_offsets);
+            rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
             if constexpr(k1_loops >= 2 &&
                          LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                 __builtin_amdgcn_s_barrier();
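
For reviewers, here is a minimal host-side sketch of the arithmetic involved (illustrative only: the strides, page count, and variable names below are made up for the example rather than taken from the kernel):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>

int main()
{
    // Hypothetical paged-KV geometry: 128 tokens per page, head_dim 128, fp16.
    const int64_t stride_token      = 128 * 2;            // bytes per token
    const int64_t stride_page_block = 128 * stride_token; // bytes per page (32 KiB)
    const int64_t num_total_pages   = 200000;             // ~6.1 GiB of K cache

    // Old path: the full global offset is folded into the 32-bit voffset.
    const int64_t max_global_offset = (num_total_pages - 1) * stride_page_block;
    if(max_global_offset > std::numeric_limits<int32_t>::max())
        std::cout << "global offset overflows int32 -> buffer loads read wrong addresses\n";

    // New path (page_block_size >= kN0): the 64-bit page base moves into the
    // SRD base pointer, and voffset only carries the within-page offset,
    // which is bounded by one page and always fits in 32 bits.
    const int64_t token_idx_in_page  = 127; // last token of a page
    const int64_t within_page_offset = token_idx_in_page * stride_token;
    assert(within_page_offset < stride_page_block);
    assert(within_page_offset <= std::numeric_limits<int32_t>::max());

    std::cout << "within-page offset: " << within_page_offset << " bytes\n";
    return 0;
}
```

This mirrors the new asserts above: they only guard the `page_block_size < kN0` path, where the full global offset still has to fit into the 32-bit voffset.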