From 4933100b0fc7754db35a692fa7a796ddfff6de21 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 3 Feb 2026 09:00:42 +0800 Subject: [PATCH] use statically_indexed_array instead of c-style array. --- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 7622778c89..62e67a1fe1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -21,7 +21,8 @@ namespace ck_tile { // - Crosses pages: per-token lookup // - Single page: lane0 lookup once, broadcast to all // Output: physical_pages array with kLoopCount elements -template {}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0.value] = page_idx[page_id]; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0] = page_idx[page_id]; }); } else @@ -73,7 +74,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - physical_pages[k0.value] = page_idx[global_token_idx]; + physical_pages[k0] = page_idx[global_token_idx]; }); } else if constexpr(kVTileCrossesPages) @@ -83,8 +84,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - physical_pages[k0.value] = page_idx[page_id]; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0] = page_idx[page_id]; }); } else @@ -96,7 +97,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, const index_t shared_physical_page = page_idx[lane0_page_id]; static_for<0, kLoopCount, 1>{}( - [&](auto k0) { physical_pages[k0.value] = shared_physical_page; }); + [&](auto k0) { physical_pages[k0] = shared_physical_page; }); } } } @@ -123,7 +124,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, // LINEAR_LAYOUT: [page, token_in_page, head_dim] // VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize] // -template -CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount], +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, const CoordVecType& coord_vec, - OffsetVecType& kv_offset_vec, + IndexArrayType& kv_offset_vec, index_t global_seq_offset = 0) { static constexpr index_t kLog2PageSize = [] { @@ -164,7 +165,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - const index_t physical_page = physical_pages[k0.value]; + const index_t physical_page = physical_pages[k0]; kv_offset_vec[k0] = static_cast(physical_page) * stride_page_block + static_cast(token_idx_in_page) * stride_token; @@ -181,7 +182,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - const index_t physical_page = physical_pages[k0.value]; + const index_t physical_page = physical_pages[k0]; const long_index_t page_base_offset = static_cast(physical_page) * stride_page_block; @@ -574,8 +575,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Load physical pages first, then compute offsets. // k_physical_pages can be reused for descale lookup later. - index_t k_physical_pages[NRepeat] = {}; - load_physical_pages k_physical_pages{}; + load_physical_pages, + decltype(k_coord), 0, kPageBlockSize, 0, @@ -737,7 +739,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync statically_indexed_array v_offsets; // V physical pages array for use with kv_offset_array_transform // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner - index_t v_physical_pages[V_PageIdxRepeat] = {}; + statically_indexed_array v_physical_pages{}; // Prefetch V physical pages - can be called early to hide buffer load latency auto prefetch_v_physical_pages = [&](auto k_loop_start) { @@ -746,8 +748,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_for<0, V_KIterOuter, 1>{}([&](auto k2) { // Load physical pages for this k2 slice into the appropriate portion of array - index_t v_physical_pages_k2[V_KIterInner] = {}; - load_physical_pages v_physical_pages_k2{}; + load_physical_pages, + decltype(v_coord), I1, kPageBlockSize, kLoopStart + k2.value * V_KLanes * V_KIterInner, @@ -759,14 +762,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Copy to merged array static_for<0, V_KIterInner, 1>{}([&](auto k1) { - constexpr auto idx = k1.value + k2.value * V_KIterInner; - v_physical_pages[idx] = v_physical_pages_k2[k1.value]; + constexpr auto idx = number{}; + v_physical_pages[idx] = v_physical_pages_k2[k1]; }); }); } else { - load_physical_pages, + decltype(v_coord), I1, kPageBlockSize, kLoopStart, @@ -789,10 +793,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, V_KIterOuter, 1>{}([&](auto k2) { statically_indexed_array v_offsets_k2; // Extract physical pages for this k2 slice - index_t v_physical_pages_k2[V_KIterInner]; + statically_indexed_array v_physical_pages_k2; static_for<0, V_KIterInner, 1>{}([&](auto k1) { - constexpr auto idx = k1.value + k2.value * V_KIterInner; - v_physical_pages_k2[k1.value] = v_physical_pages[idx]; + constexpr auto idx = number{}; + v_physical_pages_k2[k1] = v_physical_pages[idx]; }); kv_offset_array_transform, @@ -893,7 +897,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) { const index_t scale_offset = - k_physical_pages[0] * kv_block_descale_stride_block + + k_physical_pages[number<0>{}] * kv_block_descale_stride_block + block_indices.kv_head_idx * kv_block_descale_stride_head; k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv]; v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv]; @@ -1342,7 +1346,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); // KV_BLOCKSCALE: reload physical pages for the new tile - load_physical_pages, + decltype(k_coord), 0, kPageBlockSize, 0,