From ece69df994115f719f2cd035be84235caf4ef3ab Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 19 Jan 2026 20:56:23 +0800 Subject: [PATCH] Improve execution time of batch prefill kernel with vectorized KV cache layout - Use column-major distribution for V to enable buffer_load_dwordx4 - Adapt page lookup mechanism to support column-major distribution in pipeline --- .../core/tensor/tile_scatter_gather.hpp | 112 +++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 154 +++++++----- ...pipeline_qr_ks_vs_async_default_policy.hpp | 235 +++++++++++++++++- 3 files changed, 420 insertions(+), 81 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 2ffaff2973..1097188567 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -34,9 +34,9 @@ template + index_t HsGatherDim = 0, + index_t NumCoord = 1, + typename YsGatherDims = sequence<0>> struct tile_scatter_gather { using BottomTensorView = remove_reference_t; @@ -77,6 +77,51 @@ struct tile_scatter_gather using BottomTensorCoord = decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + CK_TILE_DEVICE static constexpr bool is_gather_dim(index_t i) + { + bool found = false; + static_for<0, YsGatherDims::size(), 1>{}([&](auto k) { + if(i == YsGatherDims::at(k)) + found = true; + }); + return found; + } + + template + CK_TILE_DEVICE static constexpr auto get_gather_index(const YsIndex& idx_ys_start) + { + // TODO: Consider making ys_lengths_ part of public API or adding accessor + static_assert(sizeof(TileDstr::DstrEncode::detail::ys_lengths_) > 0, + "Relies on internal detail::ys_lengths_"); + + constexpr index_t num_gather_dims = YsGatherDims::size(); + + if constexpr(num_gather_dims == 1) + { + return idx_ys_start[number{}]; + } + else + { + // Recursive lambda to compute index as a compile-time number + auto recurse = [&](auto self, auto i_constant) { + constexpr index_t i = decltype(i_constant)::value; + constexpr index_t dim = YsGatherDims::at(i); + auto current_val = idx_ys_start[number{}]; + + if constexpr(i + 1 < num_gather_dims) + { + constexpr index_t len = TileDstr::DstrEncode::detail::ys_lengths_[dim]; + return current_val + self(self, number{}) * number{}; + } + else + { + return current_val; + } + }; + return recurse(recurse, number<0>{}); + } + } + struct load_store_traits { private: @@ -375,7 +420,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor @@ -427,7 +472,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -485,7 +530,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // merge page_offset into bottom_coord @@ -513,7 +558,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -598,7 +643,7 @@ struct tile_scatter_gather }(); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor @@ -624,7 +669,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -718,7 +763,7 @@ struct tile_scatter_gather }(); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; auto mixed_bottom_thread_coord = bottom_tensor_thread_coord; @@ -748,7 +793,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -791,7 +836,7 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number<0>{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // read from distributed tensor @@ -837,7 +882,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -874,11 +919,11 @@ struct tile_scatter_gather // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = idx_ys_start[number<0>{}]; + constexpr auto idx_gather = get_gather_index(idx_ys_start); const auto page_offset = page_idx_[idx_gather]; // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", - // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + // get_gather_index(idx_ys_start)+0, idx_ys_start[number<1>{}]+0); // read from distributed tensor // vector_type_t vec; @@ -928,7 +973,7 @@ struct tile_scatter_gather constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, number{}); constexpr auto idx_diff_ps_ys = container_concat( @@ -1076,10 +1121,12 @@ struct tile_scatter_gather }; // TODO: use strategy +// Overload for sequence based gather dimensions template CK_TILE_DEVICE constexpr auto @@ -1088,6 +1135,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view, const multi_index& origin, const StaticTileDistribution_& tile_distribution, const StaticPageIndexArray_& page_idx, // perbytes + sequence, number = {}, number = {}) { @@ -1097,7 +1145,37 @@ make_tile_scatter_gather(const TensorView_& tensor_view, remove_cvref_t, std::nullptr_t, HsGatherDim, - NumCoord>{ + NumCoord, + sequence>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; +} + +// Legacy overload (compatible with original API) +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, // perbytes + number = {}, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + std::nullptr_t, + HsGatherDim, + NumCoord, + sequence>{ tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index efd2adf257..75c204e8b4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -508,32 +508,97 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); - auto v_dist = Policy::template MakeVDramTileDistribution(); - auto v_coord = v_dist.calculate_index(); - const auto VPageIndexDim = I1; - using VDstrEncode = typename decltype(v_dist)::DstrEncode; - constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; - statically_indexed_array v_offsets; - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - 0, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + auto v_dist = Policy::template MakeVDramTileDistribution(); + auto v_coord = v_dist.calculate_index(); + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + constexpr index_t V_Hs1Size = VDstrEncode::hs_lengthss_[I1].size(); + constexpr index_t V_KInnerRepeat = VDstrEncode::hs_lengthss_[I1].back(); + constexpr index_t V_KMiddle = [] { + if constexpr(V_Hs1Size == 3) + return VDstrEncode::hs_lengthss_[I1][I1]; + else + return VDstrEncode::hs_lengthss_[I1][I0]; + }(); + constexpr index_t V_KOuterRepeat = [] { + if constexpr(V_Hs1Size == 3) + return VDstrEncode::hs_lengthss_[I1][I0]; + else + return 1; + }(); + constexpr index_t V_PageIdxRepeat = V_KInnerRepeat * V_KOuterRepeat; + + constexpr auto VPageIndexYDims = []() { + constexpr index_t K1Minor = VDstrEncode::hs_lengthss_[I1].size() - 1; + constexpr index_t Y_K1 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][K1Minor]; + + if constexpr(V_Hs1Size == 3) + { + constexpr index_t Y_K2 = VDstrEncode::detail::rhs_major_minor_to_ys_[2][I0]; + return sequence{}; + } + else + { + return sequence{}; + } + }(); + + static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY, + "V page-index Y dim must be valid"); + + statically_indexed_array v_offsets; + auto update_v_offsets = [&](auto k_loop_start) { + constexpr index_t kLoopStart = decltype(k_loop_start)::value; + // For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice + // The global K offset for (k2, k1) is: kLoopStart + k2 * (K0 * K1) + k1 + // We iterate K2 outer, K1 inner, and merge into 1D v_offsets array + if constexpr(V_KOuterRepeat > 1) + { + static_for<0, V_KOuterRepeat, 1>{}([&](auto k2) { + statically_indexed_array v_offsets_k2; + kv_offset_array_transform, + decltype(v_coord), + I1, + kPageBlockSize, + kLoopStart + k2.value * V_KMiddle * V_KInnerRepeat, + V_KInnerRepeat, + 1, + kKVMemoryLayout, + false, + kN0, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k); + static_for<0, V_KInnerRepeat, 1>{}([&](auto k1) { + constexpr auto idx = number{}; + v_offsets[idx] = v_offsets_k2[k1]; + }); + }); + } + else + { + kv_offset_array_transform, + decltype(v_coord), + I1, + kPageBlockSize, + kLoopStart, + V_KInnerRepeat, + 1, + kKVMemoryLayout, + false, + kN0, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + } + }; + update_v_offsets(number<0>{}); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), {0, seqlen_k_start}, // TODO: hdim split? v_dist, v_offsets, - VPageIndexDim); + VPageIndexYDims, + number<1>{}); // prefetch K tile async_load_tile_raw( @@ -600,18 +665,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); const auto p = [&]() { @@ -741,7 +795,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(0x7F); // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && + kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); @@ -762,8 +818,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync get_slice_tile(v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + const auto v_store_tile = tile_elementwise_in(v_element_func, v_buf); + store_tile(v_lds_window_tmp, v_store_tile); // store the prefetch } if constexpr(k1_loops > 1) @@ -774,18 +830,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kK1}); // will have scratch if move this right after load_tile(v_dram)... v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - 2 * kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); } __builtin_amdgcn_sched_barrier(0); @@ -913,18 +958,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - kv_offset_array_transform, - decltype(v_coord), - VPageIndexDim, - kPageBlockSize, - (2 + i_k1.value) * kK1, - V_KRepeat, - 1, - kKVMemoryLayout, - false, - kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -936,7 +970,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && + kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp index 33e6ad006a..61cc6a6a12 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp @@ -4,15 +4,240 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" namespace ck_tile { // This pipeline is qkv all located in LDS -using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy = - BlockFmhaPipelineQXKSVSCustomPolicy; +struct BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + using Base = BlockFmhaPipelineQXKSVSCustomPolicy; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + using VDataType = remove_cvref_t; + constexpr index_t kDwordx4Bytes = 16; + return kDwordx4Bytes / sizeof(VDataType); + } + else + { + return Base::template GetAlignmentV(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, kKPack should match GEMM's kKPerThread + // to ensure correct LDS access pattern + constexpr auto gemm_k_decomp = GetGemmKDecomposition(); + constexpr index_t kKPerThread = gemm_k_decomp.template at<1>(); + return kKPerThread; + } + else + { + return Base::template GetSmemKPackV(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, we need to use our GetSmemKPackV for V size calculation + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = Base::template GetSmemKPackK(); + constexpr index_t KVector = Base::template GetAlignmentK(); + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && + WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); // Use our override! + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + else + { + return Base::template GetSingleSmemElementSpaceSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number()>{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, + number{}, + number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else + { + return Base::template MakeVLdsBlockDescriptor(); + } + } + + // Helper to get GEMM's K decomposition parameters (kABKLane, kKPerThread) + template + CK_TILE_HOST_DEVICE static constexpr auto GetGemmKDecomposition() + { + // Get the KV block GEMM and extract warp gemm's K decomposition + constexpr auto gemm = Base::template GetKVBlockGemm(); + using BlockGemm = remove_cvref_t; + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + // Return kABKLane and kKPerThread from warp gemm + return make_tuple(number{}, + number{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + if constexpr(Problem::kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // For VECTORIZED_LAYOUT, use column-major distribution (K direction vector load) + // The K decomposition must match GEMM's BWarpDstrEncoding to ensure correct LDS access + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + // Get GEMM's K decomposition (kABKLane, kKPerThread) + constexpr auto gemm_k_decomp = GetGemmKDecomposition(); + constexpr index_t kABKLane = gemm_k_decomp.template at<0>(); + constexpr index_t kKPerThread = gemm_k_decomp.template at<1>(); + + // K1 = kKPerThread (inner K dimension, matches GEMM's expectation) + // K0 = kKPerBlock / K1 (outer K dimension) + // But we need K0 to match kABKLane for the per-warp iteration + constexpr index_t K1 = kKPerThread; + constexpr index_t K0 = kABKLane; + // K0 * K1 may be less than kKPerBlock, so we need outer iteration + constexpr index_t KPerIter = K0 * K1; + constexpr index_t KOuterIter = kKPerBlock / KPerIter; + + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0, "N0 is zero"); + + if constexpr(KOuterIter == 1) + { + // Simple case: K decomposition matches exactly + constexpr auto dstr = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<2, 1>, + sequence<1, 0>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr; + } + else + { + // Need outer K iteration + constexpr index_t K2 = KOuterIter; + constexpr auto dstr = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, + sequence<2, 0, 0>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr; + } + } + else + { + // For non-VECTORIZED_LAYOUT, use base class implementation + return Base::template MakeVDramTileDistribution(); + } + } +}; } // namespace ck_tile