mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
use iglp to improve dim256 fmha fwd in qr_ks_vs pipeline (#2711)
* add k_lds padding and iglp to improve dim256 fmha fwd * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <Jun.Lin@amd.com> * Update block_fmha_pipeline_qx_ks_vs_custom_policy.hpp * clang format Signed-off-by: JL-underdog <Jun.Lin@amd.com> * use same naming style --------- Signed-off-by: JL-underdog <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -58,6 +58,20 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
using BlockGemm0 = remove_cvref_t<decltype(Policy::template GetQKBlockGemm<Problem>())>;
|
||||
static constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
|
||||
static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
|
||||
static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
|
||||
static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
|
||||
static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
|
||||
static constexpr int NumMfmaInsts =
|
||||
(kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
@@ -281,6 +295,22 @@ struct BlockFmhaPipelineQRKSVS
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm0 = [] {
|
||||
if constexpr(kQKHeaddim == 256)
|
||||
{
|
||||
static_assert(NumMfmaInsts % 8 == 0);
|
||||
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
@@ -323,6 +353,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
@@ -341,6 +372,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
@@ -351,6 +383,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
|
||||
Reference in New Issue
Block a user