diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 24496cc755..367c569769 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -59,6 +59,25 @@ struct has_ignore_fast_exp2_flag< template static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag::value; +// A helper struct for detecting naive_hdim_load, naive_hdim_load means load tiles of +// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256 +// naive_hdim_load is currently supported by the qr_ks_vs_whole_k_prefetch_pipeline +template +struct has_naive_hdim_load_flag : std::false_type +{ +}; + +template +struct has_naive_hdim_load_flag< + T, + std::enable_if_t && + T::kIsNaiveHDimLoad>> : std::true_type +{ +}; + +template +static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag::value; + }; // namespace detail template @@ -1313,6 +1332,10 @@ struct FmhaFwdKernel static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; + constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v + ? 
FmhaPipeline::kQKHeaddim + : FmhaPipeline::kSubQKHeaddim; + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( @@ -1325,7 +1348,7 @@ struct FmhaFwdKernel { return pad_tensor_view(q_dram_naive, make_tuple(number{}, - number{}), + number{}), sequence{}); } else @@ -1350,7 +1373,7 @@ struct FmhaFwdKernel { return pad_tensor_view(k_dram_naive, make_tuple(number{}, - number{}), + number{}), sequence{}); } else @@ -1371,18 +1394,29 @@ struct FmhaFwdKernel number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(!kUseTrLoad) + { + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }; } else { @@ -1406,7 +1440,7 @@ struct FmhaFwdKernel [&]() { if constexpr(FmhaPipeline::kQLoadOnce) return make_tuple(number{}, - number{}); + number{}); else return make_tuple(number{}, number{}); }(), @@ -1416,8 +1450,8 @@ struct FmhaFwdKernel if constexpr(detail::is_n0loop_pipeline_v) { return make_tile_window(k_dram, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {0, 0}); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index b90b760a0d..90200d9e83 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -8,6 +8,52 @@ namespace ck_tile { +namespace detail { + +template +CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() +{ + if constexpr(std::is_same_v || std::is_same_v) + { + // ToDo: need support in ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 6 == 0) + // return 6; + if constexpr(ElemPerThread % 8 == 0) + return 8; + else if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else if constexpr(std::is_same_v) + { + // ToDo: need support in ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 3 == 0) + // return 3; + if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else + static_assert(false, "The data type is not supported!"); +}; + +template +CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() +{ + constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize; + + 
return GetMaxVectorSize(); +} + +}; // namespace detail + template (); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize() + { + constexpr index_t kNPerBlock = BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = BlockFmhaShape::kK0; + + return detail:: + GetDramTileAccessMaxVectorSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize() + { + constexpr index_t kNPerBlock = BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = BlockFmhaShape::kK1; + + return detail:: + GetDramTileAccessMaxVectorSize(); + }; }; template CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, @@ -170,8 +171,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kSubQKHeaddim == - KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -225,7 +225,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template 
MakeQDramSingleRepMTileDistribution()); @@ -235,7 +235,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -271,7 +271,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window auto q_lds_read_window = make_tile_window(q_lds, make_tuple(number{}, number{}), @@ -286,25 +285,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); - - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = decltype(get_slice_tile( + using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_write_windows; - statically_indexed_array k_lds_read_windows; + statically_indexed_array k_lds_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -434,7 +423,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration { static_for<0, 
n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); @@ -460,9 +449,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch }; block_sync_lds(); - gemm_0(sacc_tile, - q_tile, - k_lds_read_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); @@ -475,7 +463,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch else // the iteration is also the last iteration { static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); @@ -492,9 +480,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch }; block_sync_lds(); - gemm_0(sacc_tile, - q_tile, - k_lds_read_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); @@ -510,7 +497,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration { static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); @@ -525,9 +512,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch move_tile_window(k_dram_window, {kN0Sub, 0}); block_sync_lds(); - gemm_0(sacc_tile, - q_tile, - k_lds_read_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); @@ -540,7 +526,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch else // last iteration { static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + 
store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[number{}]), partition_index); @@ -551,9 +537,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch }; block_sync_lds(); - gemm_0(sacc_tile, - q_tile, - k_lds_read_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); @@ -568,7 +553,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch else // only preload one unroll of K for next iteration, used when kM0=128 { static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], tile_elementwise_in(k_element_func, k_tiles[I0]), partition_index); @@ -590,7 +575,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch block_sync_lds(); - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); auto tmp_tile = cast_tile(sacc_tile); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 3e430ff476..1fcf823abf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -114,16 +114,21 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - using QDataType = remove_cvref_t; + if constexpr(Problem::kLoadWholeQTileOnceThroughLds) + { + return Problem::GetQDramTileAccessMaxVectorSize(); + } + else + { + using QDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t 
kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - - return min(MaxVectorSize, ElemPerThread); + return detail:: + GetDramTileAccessMaxVectorSize(); + }; } template @@ -142,12 +147,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - - return min(MaxVectorSize, ElemPerThread); + return detail:: + GetDramTileAccessMaxVectorSize(); } template @@ -162,23 +165,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VDataType = remove_cvref_t; + // special consideration when shuffling is required before storing V to LDS + if constexpr(!Problem::kUseTrLoad) + { + using VDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - constexpr 
index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize(); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); - constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : (ElemPerThread / kMinVecLoad); + // try to avoid writing sub-dword to LDS due to poor performance + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); - return kVecLoad; + return kVecLoad; + } + else + { + return Problem::GetVDramTileAccessMaxVectorSize(); + }; } template @@ -195,11 +206,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160 + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + return kKPerBlock * kNPerBlock; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { static_assert(kKVector == kKPack); @@ -236,12 +252,23 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds + ? 
Problem::BlockFmhaShape::kM0 + : GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); constexpr index_t kKVector = GetAlignmentQ(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { static_assert(kKVector == kKPack); @@ -324,25 +351,113 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKVector = GetAlignmentQ(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + 
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + // ToDo: need more consideration for hdim72 + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + constexpr index_t kKVector = GetAlignmentQ(); + constexpr index_t OtherK = kKPerBlock / kKVector; + + if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); + + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] + return 
make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + // ToDo: need more consideration for hdim72 + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}); + }; } template @@ -350,11 +465,36 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto 
k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { static_assert(kKVector == kKPack); @@ -362,9 +502,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t DataTypeSize = sizeof(KDataType); +#ifdef __gfx950__ + // 256 contiguous bytes mapped to 64 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (64 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (64 * 4 / kKPerBlock / DataTypeSize); +#else // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes constexpr auto NLdsLayer = (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); +#endif constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(make_tuple(number{}, @@ -455,24 +601,52 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKVector = GetAlignmentK(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - return make_static_tile_distribution( - 
tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + }; } template @@ -483,43 +657,87 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - // K2 is the vector size for storing shuffled tile to LDS - constexpr index_t K2 = ElemPerThread / N1; + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = 
ElemPerThread / N1; - // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm - constexpr index_t kKPack = GetSmemKPackV(); + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack >= K2, "Check failed!"); + static_assert(kKPack >= K2, "Check failed!"); - constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); - static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return v_lds_block_desc; + return v_lds_block_desc; + } + else + { + constexpr index_t kKPack = 
GetSmemKPackV(); + + constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + + constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr auto v_lds_block_desc_naive = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }; } template @@ -529,24 +747,46 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / 
get_warp_size(); - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<2, 1>>{}); + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + }; } template @@ -556,20 +796,19 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); return make_static_tile_distribution( 
tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<2, 1>>, tuple, sequence<1, 0>>, sequence<1, 2>,