From 29c56b8aae7a4cf57c4af2caa9ece576d2703719 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 29 Jan 2026 07:18:41 +0800 Subject: [PATCH] Optimize batch prefill kernel performance for VECTORIZED_LAYOUT KV cache (#3657) - Add multi-dimensional page index support (YsGatherDims) in tile_scatter_gather - Add is_gather_dim() and get_gather_index() for multi-dim page lookup - Override MakeVDramTileDistribution() for VECTORIZED_LAYOUT to match GEMM's BWarpDstrEncoding (K decomposition: {K2, K0, K1}) - Add GetGemmKDecomposition() to retrieve kABKLane and kKPerThread - Add static_assert for RowMajor VLayout requirement in batch prefill Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: e3556fed0453e66cdebc5dad6b903f5e902cd9b4] --- .../core/tensor/tile_scatter_gather.hpp | 166 ++++++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 227 ++++++++++++----- ...pipeline_qr_ks_vs_async_default_policy.hpp | 241 +++++++++++++++++- .../pipeline/block_fmha_pipeline_problem.hpp | 3 + 4 files changed, 553 insertions(+), 84 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 2ffaff2973..aa29345892 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -26,17 +26,26 @@ namespace ck_tile { * * @tparam BottomTensorView_ Class describing & holding device tensor memory. * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. - * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions - * @tparam NumCoord TBD + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions. + * @tparam StaticPageIndexArray_ Array type holding page indices for scatter/gather. + * @tparam StaticValidArray_ Array type holding validity flags (nullptr_t if unused). + * @tparam HsGatherDim H-space dimension index used for gather lookup (default: 0). + * @tparam NumCoord Number of pre-computed coordinates for pipelining (default: 1). + * @tparam YsGatherDims Sequence of Y-space dimension indices used for page lookup. + * For single dimension: sequence<0> (default). + * For multiple dimensions: sequence where + * the combined index is computed as: + * idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * + * len[dim1] + ... */ template + index_t HsGatherDim = 0, + index_t NumCoord = 1, + typename YsGatherDims = sequence<0>> struct tile_scatter_gather { using BottomTensorView = remove_reference_t; @@ -77,6 +86,75 @@ struct tile_scatter_gather using BottomTensorCoord = decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + /** + * @brief Check if a given Y-space dimension index is a gather dimension. + * + * Gather dimensions are those specified in YsGatherDims template parameter. + * When computing forward_step_scatter, gather dimensions are set to 0 + * because page offset lookup handles address calculation for these dimensions. + * + * @param i Y-space dimension index to check + * @return true if dimension i is in YsGatherDims, false otherwise + */ + CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t i) + { + return sequence_any_of(YsGatherDims{}, [i](auto k) { return i == k; }); + } + + /** + * @brief Compute the linearized gather index from Y-space indices for page lookup. + * + * This function converts multi-dimensional Y-space indices (specified by YsGatherDims) + * into a single linearized index used to look up the page offset in page_idx_ array. 
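+     *
+     * (Worked example, illustrative values only: for YsGatherDims = sequence<0, 2>
+     * and ys_lengths_ = {4, 2, 8}, the Y-space index {y0, y1, y2} = {3, 1, 5}
+     * linearizes to 3 + 5 * 4 = 23, so page_idx_[23] supplies the page offset;
+     * y1 is not a gather dimension and does not participate.)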
+ * + * For single gather dimension (YsGatherDims::size() == 1): + * Simply returns idx_ys_start[YsGatherDims::at(0)] + * + * For multiple gather dimensions (e.g., YsGatherDims = sequence<0, 2>): + * Computes: idx[dim0] + idx[dim1] * len[dim0] + idx[dim2] * len[dim0] * len[dim1] + ... + * This is row-major linearization where earlier dimensions are inner (faster-varying). + * + * @tparam YsIndex Type of the Y-space index tuple/array + * @param idx_ys_start Current Y-space indices from space-filling curve iteration + * @return Linearized index for page_idx_ array lookup + */ + template + CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start) + { + // TODO: Consider making ys_lengths_ part of public API or adding accessor + static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0, + "Relies on internal detail::ys_lengths_"); + + constexpr index_t num_gather_dims = YsGatherDims::size(); + + if constexpr(num_gather_dims == 1) + { + return idx_ys_start[number{}]; + } + else + { + // Recursive lambda to compute index as a compile-time number + // Uses row-major linearization: idx[0] + idx[1] * len[0] + idx[2] * len[0] * len[1] + + // ... + auto recurse = [&](auto self, auto i_constant) { + constexpr index_t i = decltype(i_constant)::value; + constexpr index_t dim = YsGatherDims::at(i); + auto current_val = idx_ys_start[number{}]; + + if constexpr(i + 1 < num_gather_dims) + { + constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim]; + return current_val + self(self, number{}) * number{}; + } + else + { + return current_val; + } + }; + return recurse(recurse, number<0>{}); + } + } + struct load_store_traits { private: @@ -375,7 +453,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor @@ -427,7 +505,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -485,7 +563,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // merge page_offset into bottom_coord @@ -513,7 +591,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 
0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -598,7 +676,7 @@ struct tile_scatter_gather }(); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor @@ -624,7 +702,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -718,7 +796,7 @@ struct tile_scatter_gather }(); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; auto mixed_bottom_thread_coord = bottom_tensor_thread_coord; @@ -748,7 +826,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -791,7 +869,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number<0>{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from distributed tensor @@ -837,7 +915,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -874,11 +952,11 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number<0>{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", - // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + // get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0); // read from distributed tensor // vector_type_t vec; @@ -928,7 +1006,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -1076,6 +1154,53 @@ struct tile_scatter_gather }; // TODO: use strategy +/** + * @brief Factory function to create tile_scatter_gather with multi-dimensional gather support. + * + * This overload accepts a sequence to specify multiple Y-space dimensions + * for page lookup. Use this when the tile distribution decomposes the paged dimension + * into multiple Y-space dimensions (e.g., VECTORIZED_LAYOUT V tensor with K decomposition + * {K2, K0, K1} where both Y0 and Y2 contribute to page index). 
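+ *
+ * A typical call looks like the following sketch (the identifiers mirror the
+ * batch-prefill pipeline later in this patch; v_tensor_view and
+ * v_window_lengths are illustrative names, not a fixed API):
+ *
+ *   auto v_dram_window = make_tile_scatter_gather(v_tensor_view,
+ *                                                 v_window_lengths,
+ *                                                 {0, seqlen_k_start},
+ *                                                 v_dist,
+ *                                                 v_offsets,
+ *                                                 number<1>{},   // HsGatherDim
+ *                                                 number<1>{},   // NumCoord
+ *                                                 sequence<Y_K1, Y_K2>{});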
+ * + * @tparam HsGatherDim H-space dimension for gather + * @tparam NumCoord Number of pre-computed coordinates + * @tparam YsGatherDims Parameter pack specifying which Y-dimensions are used for page lookup + * + * @param tensor_view The underlying tensor view for device memory access + * @param window_lengths Static window sizes for each dimension + * @param origin Window origin coordinates on the bottom tensor + * @param tile_distribution Thread-to-tile mapping distribution + * @param page_idx Array of page offsets (in bytes) for scatter/gather + */ +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + number, + number, + sequence) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + std::nullptr_t, + HsGatherDim, + NumCoord, + sequence>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; +} + +// Legacy overload (compatible with original API) template & origin, const StaticTileDistribution_& tile_distribution, - const StaticPageIndexArray_& page_idx, // perbytes + const StaticPageIndexArray_& page_idx, number = {}, number = {}) { @@ -1097,7 +1222,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view, remove_cvref_t, std::nullptr_t, HsGatherDim, - NumCoord>{ + NumCoord, + sequence<0>>{ tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index c75f5d58c4..48e8f75ae7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -533,32 +533,170 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); - auto v_dist = Policy::template MakeVDramTileDistribution(); - auto v_coord = v_dist.calculate_index(); - const auto VPageIndexDim = I1; - using VDstrEncode = typename decltype(v_dist)::DstrEncode; - constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; - statically_indexed_array v_offsets; - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - 0, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + auto v_dist = Policy::template MakeVDramTileDistribution(); + auto v_coord = v_dist.calculate_index(); + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + // V tensor K-dimension decomposition for page index computation + // ============================================================ + // The K dimension (seqlen_k) in V distribution is decomposed into multiple sub-dimensions. + // This decomposition determines how threads iterate over the K dimension and how page + // indices are computed for paged KV cache. 
+ // + // The decomposition pattern differs by memory layout: + // + // VECTORIZED_LAYOUT (ColumnMajor, custom distribution): + // 3D decomposition: K = K2 × K0 × K1 + // - K2 (V_KIterOuter): Outer iteration count + // - K0 (V_KLanes): Lanes for K dimension (matches GEMM kABKLane) + // - K1 (V_KIterInner): Vector load size (matches GEMM kKPerThread) + // - hs_lengthss_[I1] = {K2, K0, K1}, size = 3 (or {K0, K1} size = 2 if no outer iter) + // + // LINEAR_LAYOUT ColumnMajor (base class distribution): + // 2D decomposition: K = K0 × K1 + // - K0: Lanes for K dimension (may not match GEMM kABKLane) + // - K1: Vector load size + // - hs_lengthss_[I1] = {K0, K1}, size = 2 + // + // LINEAR_LAYOUT RowMajor (base class distribution): + // 4D decomposition: K = K0 × K1 × K2 × K3 (uses shuffle_tile for GEMM alignment) + // 3D decomposition: K = K0 × K1 × K2 (fallback case) + // - Page lookup uses Y-space's last dimension only (inner iteration) + // + // V_PageIdxRepeat = total number of page lookups per thread = V_KIterOuter × V_KIterInner + constexpr index_t V_KIterInner = VDstrEncode::hs_lengthss_[I1].back(); + + // Compute V_KIterOuter and V_KLanes based on memory layout and K decomposition + constexpr index_t V_KIterOuter = [] { + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // VECTORIZED_LAYOUT: 3D decomposition {K2, K0, K1} when outer iteration is needed + if constexpr(VDstrEncode::hs_lengthss_[I1].size() == 3) + return static_cast(VDstrEncode::hs_lengthss_[I1][I0]); + else + return index_t{1}; + } + else + { + // LINEAR_LAYOUT: No outer iteration for page lookup + // RowMajor uses shuffle_tile, ColumnMajor has simple 2D decomposition + // Both cases use single-dimension Y-space page lookup + return index_t{1}; + } + }(); + + constexpr index_t V_KLanes = [] { + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // VECTORIZED_LAYOUT: K0 is the lanes dimension + if constexpr(V_KIterOuter > 1) + return static_cast(VDstrEncode::hs_lengthss_[I1][I1]); + else + return static_cast(VDstrEncode::hs_lengthss_[I1][I0]); + } + else + { + // LINEAR_LAYOUT: First dimension is K0 (lanes) + return static_cast(VDstrEncode::hs_lengthss_[I1][I0]); + } + }(); + + // This affects page offset computation - need to track offsets for each (k2, k1) + // combination + constexpr index_t V_PageIdxRepeat = V_KIterInner * V_KIterOuter; + + // VPageIndexYDims: Y-space dimension indices that participate in page index computation + // ================================================================================ + // In tile_scatter_gather, the gather index is computed from Y-space coordinates. + // This sequence specifies which Y dimensions should be linearized to form the page lookup + // index. 
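+        //
+        // (Illustrative numbers only, not asserted by this patch: if
+        // hs_lengthss_[I1] = {K2, K0, K1} = {2, 4, 8}, then V_KIterOuter = 2,
+        // V_KLanes = 4, V_KIterInner = 8 and V_PageIdxRepeat = 2 * 8 = 16; with
+        // two gather dimensions the lookup index for Y-coordinates (y_k1, y_k2)
+        // is y_k1 + y_k2 * 8.)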
+ // + // VECTORIZED_LAYOUT with outer iteration: sequence + // - Both K1 and K2 are in Y-space (thread iteration dimensions) + // - gather_index = y_k1 + y_k2 * len(Y_K1) (linearized 2D -> 1D) + // + // VECTORIZED_LAYOUT without outer iteration / LINEAR_LAYOUT: sequence + // - Only the innermost K dimension is used for page lookup (single dimension) + // + constexpr auto VPageIndexYDims = []() { + // K1Minor is always the last element index in hs_lengthss_[I1] + constexpr index_t K1Minor = VDstrEncode::hs_lengthss_[I1].size() - 1; + constexpr index_t Y_K1 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][K1Minor]; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT && + V_KIterOuter > 1) + { + // VECTORIZED_LAYOUT with outer iteration: need 2D page lookup + constexpr index_t Y_K2 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][I0]; + return sequence{}; + } + else + { + // LINEAR_LAYOUT or VECTORIZED_LAYOUT without outer iteration: 1D page lookup + return sequence{}; + } + }(); + + static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY, + "V page-index Y dim must be valid"); + + statically_indexed_array v_offsets; + auto update_v_offsets = [&](auto k_loop_start) { + constexpr index_t kLoopStart = decltype(k_loop_start)::value; + // For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice + // The global K offset for (k2, k1) is: kLoopStart + k2 * (K0 * K1) + k1 + // We iterate K2 outer, K1 inner, and merge into 1D v_offsets array + if constexpr(V_KIterOuter > 1) + { + static_for<0, V_KIterOuter, 1>{}([&](auto k2) { + statically_indexed_array v_offsets_k2; + kv_offset_array_transform, + decltype(v_coord), + I1, + kPageBlockSize, + kLoopStart + k2.value * V_KLanes * V_KIterInner, + V_KIterInner, + 1, + kKVMemoryLayout, + false, + kN0, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k); + static_for<0, V_KIterInner, 1>{}([&](auto k1) { + constexpr auto idx = number{}; + v_offsets[idx] = v_offsets_k2[k1]; + }); + }); + } + else + { + kv_offset_array_transform, + decltype(v_coord), + I1, + kPageBlockSize, + kLoopStart, + V_KIterInner, + 1, + kKVMemoryLayout, + false, + kN0, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + } + }; + update_v_offsets(number<0>{}); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), {0, seqlen_k_start}, // TODO: hdim split? 
v_dist, v_offsets, - VPageIndexDim); + number<1>{}, // HsGatherDim + number<1>{}, // NumCoord + VPageIndexYDims); // prefetch K tile async_load_tile_raw( @@ -625,18 +763,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); const auto p = [&]() { @@ -766,7 +893,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(0x7F); // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && + kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); @@ -787,8 +916,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync get_slice_tile(v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + const auto v_store_tile = tile_elementwise_in(v_element_func, v_buf); + store_tile(v_lds_window_tmp, v_store_tile); // store the prefetch } if constexpr(k1_loops > 1) @@ -799,18 +928,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kK1}); // will have scratch if move this right after load_tile(v_dram)... v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - 2 * kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); } __builtin_amdgcn_sched_barrier(0); @@ -938,18 +1056,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - (2 + i_k1.value) * kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -961,7 +1068,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && + kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp index 33e6ad006a..45b7356dfa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp @@ -4,15 +4,246 @@ #pragma once 
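// The policy below repeatedly uses one dispatch pattern: an override inspects
// Problem::kKVMemoryLayout with `if constexpr`, handles VECTORIZED_LAYOUT
// itself, and forwards every other layout to the base policy. The following is
// a minimal, self-contained sketch of that pattern; ExampleLayout, ExampleBase,
// ExamplePolicy and ExampleProblem are illustrative names, not ck_tile types.

namespace example_sketch {

enum class ExampleLayout
{
    LINEAR,
    VECTORIZED
};

struct ExampleBase
{
    // Generic fallback used for layouts the derived policy does not specialize.
    template <typename Problem>
    static constexpr int GetAlignmentV()
    {
        return 4;
    }
};

struct ExamplePolicy : ExampleBase
{
    using Base = ExampleBase;

    template <typename Problem>
    static constexpr int GetAlignmentV()
    {
        if constexpr(Problem::kLayout == ExampleLayout::VECTORIZED)
        {
            // VECTORIZED layout loads dwordx4 (16 bytes) per thread.
            return 16 / sizeof(typename Problem::VDataType);
        }
        else
        {
            // Every other layout defers to the base policy.
            return Base::template GetAlignmentV<Problem>();
        }
    }
};

// Compile-time check of the sketch with a made-up problem type.
struct ExampleProblem
{
    static constexpr ExampleLayout kLayout = ExampleLayout::VECTORIZED;
    using VDataType = unsigned short; // stand-in for a 2-byte element type
};
static_assert(ExamplePolicy::GetAlignmentV<ExampleProblem>() == 8,
              "dwordx4 of 2-byte elements gives an alignment of 8 elements");

} // namespace example_sketch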
#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" namespace ck_tile { // This pipeline is qkv all located in LDS -using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy = - BlockFmhaPipelineQXKSVSCustomPolicy; +struct BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + using Base = BlockFmhaPipelineQXKSVSCustomPolicy; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + using VDataType = remove_cvref_t; + constexpr index_t kDwordx4Bytes = 16; + return kDwordx4Bytes / sizeof(VDataType); + } + else + { + return Base::template GetAlignmentV(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, kKPack should match GEMM's kKPerThread + // to ensure correct LDS access pattern + constexpr auto gemm_k_decomp = GetGemmKDecomposition(); + constexpr index_t kKPerThread = gemm_k_decomp.template at<1>(); + return kKPerThread; + } + else + { + return Base::template GetSmemKPackV(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, we need to use our GetSmemKPackV for V size calculation + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = Base::template GetSmemKPackK(); + constexpr index_t KVector = Base::template GetAlignmentK(); + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && + WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); // Use our override! 
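+                // (Illustrative arithmetic, not asserted by this patch: with 32
+                // LDS banks and 2-byte V elements, PixelsPerRow = 32 * 4 / 2 = 64;
+                // if the warp GEMM's kKPerThread is 8, kKPack = 8 and
+                // NPerRow = 64 / 8 = 8, so each LDS row holds 8 N-positions of one
+                // K-pack plus kKPack padding elements.)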
+ static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + else + { + return Base::template GetSingleSmemElementSpaceSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number()>{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, + number{}, + number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else + { + return Base::template MakeVLdsBlockDescriptor(); + } + } + + // Helper to get GEMM's K decomposition parameters (kABKLane, kKPerThread) + template + CK_TILE_HOST_DEVICE static constexpr auto GetGemmKDecomposition() + { + // Get the KV block GEMM and extract warp gemm's K decomposition + constexpr auto gemm = Base::template GetKVBlockGemm(); + using BlockGemm = remove_cvref_t; + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + // Return kABKLane and kKPerThread from warp gemm + return make_tuple(number{}, + number{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, use column-major distribution (K direction vector load) + // The K decomposition must match GEMM's BWarpDstrEncoding to ensure correct LDS access + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + // Get GEMM's K decomposition (kABKLane, kKPerThread) + constexpr auto gemm_k_decomp = GetGemmKDecomposition(); + constexpr index_t kABKLane = gemm_k_decomp.template at<0>(); + constexpr index_t kKPerThread = gemm_k_decomp.template at<1>(); + + // K1 = kKPerThread (inner K dimension, matches GEMM's expectation) + // K0 = kKPerBlock / K1 (outer K dimension) + // But we need K0 to match kABKLane for the per-warp iteration + constexpr index_t K1 = kKPerThread; + constexpr index_t K0 = kABKLane; + + // Verify K 
decomposition matches GEMM's BWarpDstrEncoding requirements + static_assert(K0 == kABKLane, "K0 must match GEMM's kABKLane for correct LDS access"); + static_assert(K1 == kKPerThread, + "K1 must match GEMM's kKPerThread for correct LDS access"); + + // K0 * K1 may be less than kKPerBlock, so we need outer iteration + constexpr index_t KPerIter = K0 * K1; + constexpr index_t KOuterIter = kKPerBlock / KPerIter; + + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0, "N0 is zero"); + + if constexpr(KOuterIter == 1) + { + // Simple case: K decomposition matches exactly + constexpr auto dstr = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<2, 1>, + sequence<1, 0>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr; + } + else + { + // Need outer K iteration + constexpr index_t K2 = KOuterIter; + constexpr auto dstr = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, + sequence<2, 0, 0>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr; + } + } + else + { + // For non-VECTORIZED_LAYOUT, use base class implementation + return Base::template MakeVDramTileDistribution(); + } + } +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index a489eabb73..eabf74faf8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -121,6 +121,9 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); static_assert(kIsGroupMode_, "Batch prefill requires group mode"); + + static_assert(BlockFmhaShape_::IsVLayoutRowMajor, + "Batch prefill kernel requires RowMajor VLayout"); }; template