From 7771db8ecf8f678f9f88f405f08c8645530d654b Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Fri, 7 Mar 2025 14:19:51 +0800 Subject: [PATCH] Ck tile/complete k prefetch (#1941) * Re-implement qr_ks_vs_async pipeline by using kLoadOnce * Remove last block_sync_lds() in the loop * Tiny adjustment in qr_ks_vs_async pipeline for better performance * Rename MakeQDramTileDistribution to MakeQRegTileDistribution for QLoadOnce pipeline * Use LDS as intermediary stop when loading Q from global memory for qr_ks_vs_async pipeline * Use un-rolled gemm for Gemm-0 * Use k0_loops small tile load/store to replace the big tile load/store for K * Remove the commented lines in qx_ks_vs_custom_policy.hpp * Tune the prefetching of V in qr_ks_vs_async pipeline * Move the codes for storing the first v_lds tile some later * Let BlockDropout reuse LDS with V * Switch to separate code blocks according to iteration index * Interleave code blocks for better performance * Move clear_tile(s_acc) for better interleaving * Move code interleaving * Use MakeQDramTileDistribution for q_dram_window * Roll-back to load Q directly from global memory instead of using LDS as intermediary stop * Let V reuse the LDS of K * Use array of tiles to represent Q in vgprs * Use QLoadOnce == false for qr_ks_vs_async pipeline * Special treatment for hdim-96 to save vgprs in qr_ks_vs_async pipeline * Define statically indexed array k_lds_windows[] to reduce the using of get_slice_tile() * Move the definition of v_tiles out from the loop * Define statically indexed array v_lds_windows[] to reduce using of get_slice_tile() * Remove using KLoadOnce in qx_ks_vs_custom_policy * Remove un-used get_slice_tile() call * Move the code line of clear_tile(s_acc) * Tune the lines of codes to make them more tidy * Re-arrange the codes before the main-loop * Add comments * Unify the alignment to be 8 for Q/K/V Lds decriptors * Tuning to K pre-loading * Tune K Lds and V Lds reuse for kPreloadWholeNextIterationK == false * Adjust the pipeline codes * Use NumPrefetchV to separate from NumVLdsBuffers * Tune the location of a scheduler barrier code line * Prefetch first v_tile at earlier time for both kPreloadNextWholeIterationK true/false paths * Adjust the using of kPadSeqLenQ and kPadSeqLenK in the kernel * Use __builtin_amdgcn_sched_barrier(0x7f) in the pipeline * Move the location for store_tile() of first v_tile * Rename the qr_ks_vs_async pipeline to qr_ks_vs_whole_k_prefetch pipeline * Re-add NumPrefetchK as template for BlockFmhaPipelineQXKSVSCustomPolicy<> * Try to fix old bugs in qx_ks_vs_custom_policy * Remove K_LDS_LOAD_USE_OFFSET_TRANSFORM code-path to make qr_ks_vs_async and qx_ks_vs_custom_policy simpler * Fix in MakeKDramTileDistribution() in qx_ks_vs_custom_policy * Update to LdsBufferSequence and introduce NumKVLdsBuffers for max(NumPrefetchK, NumPrefetchV) * Tiny Fix (#1888) * Ck tile/paged attention workaround (#1894) * Correction in GetRangeAlongX() * Work-around to solve the failures in test_paged_attention_ck in xformers * Tiny code adjustment in the qr_ks_vs_whole_k_prefetch pipeline * Remove one call of move_tile_window for q_dram_window * Refine the codes in GetNumPrefetchV()/GetNumKLdsBuffers() * Tiny fix in qr_ks_vs_whole_k_prefetch pipeline * Adjust the location of codes for storing the first V tile to LDS * Tiny fix and add comments * Change GetSmemKPackK size to improve performance * Move the codes related to K-Lds to the pipeline default policy due to some override on the generic custom_policy * Update MakeKDramTileDistribution() and MakeKLdsDescriptor() to completely remove bank conflicts for K-Lds access * Adjustment in intermediate iteration codes for tiny performance improvement * Reduce the number of VLds buffers to 2 for whole_k_prefetch situtation * Use IsFirstKLdsBufferOverlapLastVLdsBuffer() to avoid potential Lds issue * Adjust the code location for calling IsFirstKLdsBufferOverlapLastVLdsBuffer() * Remove useless AsyncopyV * Rename MakeQDramTileDistribution to MakeQRegTileDistribution when LDS is not used * Keep qx_ks_vs_custom_policy work for other pipelines and move whole_k_prefetch specific codes to whole_k_prefetch default policy * Recover the qr_ks_vs_async pipeline * Recover qr_ks_vs_async in fmha.hpp and tiny fix in qr_ks_vs pipeline * Revert "Try to fix old bugs in qx_ks_vs_custom_policy" This reverts commit 39b82ca194f5926aec8fd15af5696b6ed131f6d8. * Tiny fix with regard to whole_k_prefetch pipeline compiling * Update kPadSeqLenK setting in fmha_fwd_kernel * Use q_element_func and k_element_func * Use single q_tile rather than multiple sliced q_tiles * Codes refine according to the comments * Re-format one file * Mark qr_ks_vs_whole_k_prefetch as QLoadOnec == true [ROCm/composable_kernel commit: 4f54fa30583704f34da2ac50372d524cae6bad7d] --- include/ck_tile/ops/fmha.hpp | 4 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 11 +- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 16 +- ...nwarp_sshuffle_qr_ks_vs_default_policy.hpp | 10 +- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 15 +- ...litkv_pipeline_qr_ks_vs_default_policy.hpp | 3 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 15 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 53 +- ...pipeline_qr_ks_vs_async_default_policy.hpp | 3 +- ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 4 +- ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 929 ++++++++++++++++++ ..._ks_vs_whole_k_prefetch_default_policy.hpp | 379 +++++++ .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 2 + ..._fmha_pipeline_qs_ks_vs_default_policy.hpp | 3 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 116 +-- 15 files changed, 1418 insertions(+), 145 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index c896534e03..2618082e5b 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -33,9 +33,11 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index f107b10dff..c671463252 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -54,6 +54,8 @@ struct FmhaFwdKernel using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; @@ -1082,10 +1084,11 @@ struct FmhaFwdKernel number{}, number<1>{}); + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; return pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) @@ -1104,10 +1107,11 @@ struct FmhaFwdKernel make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { @@ -1118,10 +1122,11 @@ struct FmhaFwdKernel number{}, number<1>{}); + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 3d53535b28..809c58f1d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -44,6 +44,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; @@ -97,6 +99,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { return 1; } + else + { + return 1; + } } }(); @@ -316,11 +322,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS // load Q from LDS __builtin_amdgcn_sched_barrier(0); - auto q_lds_window_for_load = make_tile_window( - q_lds, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); + auto q_lds_window_for_load = + make_tile_window(q_lds, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); block_sync_lds(); auto q = load_tile(q_lds_window_for_load); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp index 74d755ef39..9d8f6bc99f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp @@ -13,14 +13,12 @@ namespace ck_tile { // This pipeline is qkv all located in LDS struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy : BlockFmhaPipelineQXKSVSCustomPolicy { using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; @@ -76,10 +74,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy sequence<0, 1>>{}); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { - return BasePolicy::template MakeQDramTileDistribution(); + return BasePolicy::template MakeQRegTileDistribution(); } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 04aa85644d..ce80dba5eb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -43,6 +43,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; @@ -96,6 +98,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { return 1; } + else + { + return 1; + } } }(); @@ -180,11 +186,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); auto q = load_tile(q_dram_window); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp index b7f1f042ed..ccc4f23817 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -11,8 +11,7 @@ namespace ck_tile { // This pipeline is qkv all located in LDS struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy : BlockFmhaPipelineQXKSVSCustomPolicy { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index a7e9287143..8a4a925b81 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -45,6 +45,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; @@ -96,6 +98,10 @@ struct BlockFmhaPipelineQRKSVS { return 1; } + else + { + return 1; + }; } }(); @@ -178,11 +184,10 @@ struct BlockFmhaPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); auto q = load_tile(q_dram_window); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 173887513e..d64e5562d0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,6 +46,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) // only need special care about seq_k padding (oob need set -INF of p instead of zero) @@ -114,6 +116,10 @@ struct BlockFmhaPipelineQRKSVSAsync { return 1; } + else + { + return 1; + }; } }(); @@ -189,19 +195,8 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), {0, 0, 0}); }, - number{}); + number{}); -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - auto k_lds_load = generate_tuple( - [&](auto i_buf) { - return make_tile_window( - make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), - Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), - {0, 0}); - }, - number{}); -#else auto k_lds_Load_view = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); @@ -209,7 +204,6 @@ struct BlockFmhaPipelineQRKSVSAsync make_tile_window(k_lds_Load_view, Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), {0, 0}); -#endif // V tile in LDS auto v_lds = make_tensor_view( @@ -222,11 +216,10 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm @@ -368,14 +361,9 @@ struct BlockFmhaPipelineQRKSVSAsync gemm_0(s_acc, get_slice_tile( q, sequence<0, i_k0 * kK0>{}, sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[number{})>{}]); - -#else get_slice_tile(k_lds_load, sequence<(LdsSeq.at(number{})) * kN0, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); -#endif }); } @@ -391,18 +379,13 @@ struct BlockFmhaPipelineQRKSVSAsync auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail - gemm_0(s_acc, - get_slice_tile( - q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[number{})>{}]); - -#else - get_slice_tile( - k_lds_load, - sequence<(LdsSeq.at(number{})) * kN0, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); -#endif + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp index 7824bbdefb..e92ba58b37 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -11,8 +11,7 @@ namespace ck_tile { // This pipeline is qkv all located in LDS using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = BlockFmhaPipelineQXKSVSCustomPolicy; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index 6ce4591aff..e905037398 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -8,11 +8,9 @@ namespace ck_tile { -// This pipeline is qkv all located in LDS using BlockFmhaPipelineQRKSVSDefaultPolicy = BlockFmhaPipelineQXKSVSCustomPolicy; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp new file mode 100644 index 0000000000..cc532040e8 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -0,0 +1,929 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaPipelineQRKSVSWholeKPrefetch +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + ignore = q_element_func; + ignore = k_element_func; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(2 <= k0_loops); + static_assert(2 <= k1_loops); + + constexpr bool kPreloadWholeNextIterationK = + Policy::template IsPreloadWholeNextIterationK(); + + constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); + constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); + constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); + + static_assert(NumKLdsBuffers >= 2); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + auto k_tiles = [&]() { + if constexpr(kPreloadWholeNextIterationK) + return statically_indexed_array{}; + else + return statically_indexed_array{}; + }(); + + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + + auto q_tile = load_tile(q_dram_window); + + __builtin_amdgcn_sched_barrier(0); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + using k_lds_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_windows; + + static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile( + k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetExclusiveKLdsBytes()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); + }); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + index_t i_total_loops = 0; + + do + { + if constexpr(kPreloadWholeNextIterationK) + { + if(i_total_loops == 0) // executed by fist iteration + { + if(num_total_loop > 1) // there are multiple iterations + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + store_tile( + k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + k_tiles[number{}] = load_tile(k_dram_window); + if constexpr(i_k0 < k0_loops - 2) + move_tile_window(k_dram_window, {0, kK0}); + + if constexpr(i_k0 == 0) + clear_tile(s_acc); + + block_sync_lds(); + // execute current unroll of gemm_0 + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_windows[number{}]); + }); + + store_tile( + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); + + // prefetch all k_tiles for next iteration + static_for<0, k0_loops, 1>{}([&](auto i_k0) { + k_tiles[number{}] = load_tile(k_dram_window); + + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + }); + + move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); + + block_sync_lds(); + // execute last unroll of gemm_0 + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); + } + else // there is only single iteration + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + store_tile( + k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + k_tiles[number{}] = load_tile(k_dram_window); + if constexpr(i_k0 < k0_loops - 2) + move_tile_window(k_dram_window, {0, kK0}); + + if constexpr(i_k0 == 0) + clear_tile(s_acc); + + block_sync_lds(); + // execute current unroll of gemm_0 + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_windows[number{}]); + }); + + store_tile( + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); + + // move_tile_window(k_dram_window, {0, -k0_loops * kK0}); + } + } + else // executed by intermediate and last iteration + { + if(i_total_loops < num_total_loop - 1) // intermediate iteration + { + store_tile(k_lds_windows[I0], + tile_elementwise_in(k_element_func, k_tiles[I0])); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + clear_tile(s_acc); + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), + k_lds_windows[I0]); + + store_tile(k_lds_windows[I1], + tile_elementwise_in(k_element_func, k_tiles[I1])); + + move_tile_window(k_dram_window, {kN0, 0}); + + // prefetch first k_tile for next iteration + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + + k_tiles[I1] = load_tile(k_dram_window); + if constexpr(1 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, sequence<0, kK0>{}, sequence{}), + k_lds_windows[I1]); + + // during the gemm-loop, also prefetch other k_tiles for next iteration + static_for<2, k0_loops, 1>{}([&](auto i_k0) { + store_tile(k_lds_windows[number{}], + k_tiles[number{}]); + + k_tiles[number{}] = load_tile(k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_windows[number{}]); + }); + + move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); + } + else // last iteration + { + store_tile(k_lds_windows[I0], + tile_elementwise_in(k_element_func, k_tiles[I0])); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + clear_tile(s_acc); + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), + k_lds_windows[I0]); + + static_for<1, k0_loops, 1>{}([&](auto i_k0) { + store_tile( + k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_windows[number{}]); + }); + }; + }; + } + else // only preload one unroll of K for next iteration + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[I0])); + if constexpr(i_k0 == 0) + clear_tile(s_acc); + + if constexpr(i_k0 < k0_loops - 1) + k_tiles[I0] = load_tile(k_dram_window); + if constexpr(i_k0 < k0_loops - 2) + move_tile_window(k_dram_window, {0, kK0}); + + block_sync_lds(); + // execute current unroll of gemm_0 + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_windows[number{}]); + }); + + store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], + tile_elementwise_in(k_element_func, k_tiles[I0])); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); + }; + + __builtin_amdgcn_sched_barrier(0); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { + v_tiles[i_buf] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); + dropout.template Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + + __builtin_amdgcn_sched_barrier(0x7f); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[I0]); + + store_tile( + v_lds_windows[I0], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_windows[I0], + tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch + } + + __builtin_amdgcn_sched_barrier(0); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + if constexpr(!kPreloadWholeNextIterationK) + { + if(i_total_loops < num_total_loop - 1) + { + move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + }; + + __builtin_amdgcn_sched_barrier(0); + } + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + v_tiles[I0] = load_tile(v_dram_window); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[I0]); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); + } + else + { + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0])); + } + + move_tile_window(v_dram_window, {0, kK1}); + }); + } + else // NumVLdsBuffers == 3 or 2 + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + v_tiles[number{}] = load_tile(v_dram_window); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); + } + else + { + store_tile( + v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); + } + + if constexpr(i_k1 < k1_loops - NumPrefetchV) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); + + if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + }; + + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp new file mode 100644 index 0000000000..67ab548dab --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + static constexpr index_t NumPrefetchV = 2; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsPreloadWholeNextIterationK() + { + return Problem::BlockFmhaShape::kM0 <= 64; + }; + + template + CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + { + return 2; + } + + template + CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + + constexpr index_t k1_loops = kN0 / kK1; + + return min(NumPrefetchV, k1_loops); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() + { + return 2; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeABlockTileDistribution< + Problem::BlockFmhaShape::kM0, + Problem::BlockFmhaShape::kQKHeaddim>(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using KDataType = remove_cvref_t; + return 8 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + static_assert(kKVector % kKPack == 0); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + using VDataType = remove_cvref_t; + + constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); + + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr index_t VSingleSmemElementSpaceSize = + (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(ElemPerThread % N1 == 0); + constexpr index_t K3 = ElemPerThread / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = GetAlignmentV(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 32); + + // TODO: hard coded here. Otherwise, it may incorrect result + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } // TODO - bf8_t + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } + + // leave some exclusive space so that the second v_lds buffer will nenver overlap with the first + // k_lds bufffer + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() + { + constexpr index_t single_k_lds_buffer_size = + GetSmemSizeK() / GetNumKLdsBuffers(); + constexpr index_t single_v_lds_buffer_size = + GetSmemSizeV() / GetNumVLdsBuffers(); + + if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) + return 0; + else + return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; + constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); + constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); + + constexpr index_t last_v_lds_buffer_offset = + MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * + ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); + + constexpr index_t first_k_lds_buffer_size = + MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * + sizeof(typename Problem::KDataType); + + return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < + first_k_lds_buffer_size; + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // assume V can reuse the other shared memory by K except the first + // assume Dropout can reuse the shared memory by V + return GetExclusiveKLdsBytes() + + max(GetSmemSizeK() - GetExclusiveKLdsBytes(), + max(GetSmemSizeV(), GetSmemSizeDropout(0))); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index c2223fcee6..7be6a347f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -94,6 +94,8 @@ struct BlockFmhaPipelineQSKSVS { return 1; } + else + return 1; } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp index ff8299b4ff..7505dbb172 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -11,8 +11,7 @@ namespace ck_tile { // This pipeline is qkv all located in LDS struct BlockFmhaPipelineQSKSVSDefaultPolicy : BlockFmhaPipelineQXKSVSCustomPolicy { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 3db461e971..26f7e46f9f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -17,9 +17,6 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" -// TODO: remove this -#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 - namespace ck_tile { template @@ -50,9 +47,11 @@ struct BlockFmhaPipelineQXCustomPolicy return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { + using BlockGemm = remove_cvref_t())>; + return BlockGemm::template MakeABlockTileDistribution< Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kSubQKHeaddim>(); @@ -278,37 +277,43 @@ struct BlockFmhaPipelineQXCustomPolicy }; // This pipeline is qkv all located in LDS -template +template struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy { - static constexpr bool AsyncCopyK = AsyncCopyK_; - static constexpr bool AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet + static constexpr bool AsyncCopy = AsyncCopy_; static constexpr index_t NumPrefetchK = NumPrefetchK_; static constexpr index_t NumPrefetchV = NumPrefetchK_; + static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV); + using QXPolicy = BlockFmhaPipelineQXCustomPolicy; template struct LdsBufferSequence { + static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_); + static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_; + + // for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not + // overlap with the Lds buffers used by first two gemm_0 iterations of K static constexpr auto Make() { + // ensure v_loop_-1 is assigned to num_lds_buffers-1 return transform_sequences( [&](auto i) { if(i < k_loops_) - return i % k_prefetches_; - return (i - k_loops_) % v_prefetches_; + return i % num_lds_buffers_; + else + return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) % + num_lds_buffers_; }, typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); }; using type = remove_cvref_t; }; + // clang-format off template<> struct LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; @@ -357,13 +362,20 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - if constexpr(AsyncCopyK) + if constexpr(AsyncCopy) { return 4 / sizeof(KDataType); } else { - return 16 / sizeof(KDataType); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); } } @@ -427,7 +439,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy().get_element_space_size(); } @@ -549,55 +561,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsLoadBlockDescriptor(number = number<0>{}) - { - // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck_tile::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackK(); // this is for lds - constexpr index_t KVector = GetAlignmentK(); // this is for global load - constexpr index_t kPad = KPack; // for async-copy, this pad is between warps - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( - make_tuple(number{}, // n0 - number{}, // n2 - number{}, // n1 - number{}, // k0 - number{}), // k1 - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number()>{}, - number{}, - number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform( - make_tuple(number{}, number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return k_lds_block_desc; - } -#else template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() { @@ -624,7 +587,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // max(SingleKSize, SingleVSize); constexpr auto k_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, // num_buffers + make_naive_tensor_descriptor(make_tuple(number{}, // num_buffers number{}, // n0 number{}, // n2 number{}, // n1 @@ -642,7 +605,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, + make_merge_transform(make_tuple(number{}, number{}, number{}, number{})), @@ -652,7 +615,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy @@ -670,7 +632,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, + make_tuple(number{}, number{}, number{}, number{}, @@ -687,7 +649,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, number{}, number{})), + number{}, number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -703,14 +665,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy() * sizeof(typename Problem::KDataType); - return QXPolicy::template GetSmemSizeQ() + - single_smem_size * max(NumPrefetchK, NumPrefetchV); + return QXPolicy::template GetSmemSizeQ() + single_smem_size * NumKVLdsBuffers; } template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - if constexpr(AsyncCopyK) + if constexpr(AsyncCopy) { return GetSmemSizeKV() + GetSmemSizeDropout(0); } @@ -754,7 +715,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { - if constexpr(!AsyncCopyK) + if constexpr(!AsyncCopy) { using KDataType = remove_cvref_t; @@ -762,7 +723,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy