use iglp to improve dim256 fmha fwd in qr_ks_vs pipeline (#2711)

* add k_lds padding and iglp to improve dim256 fmha fwd

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update block_fmha_pipeline_qr_ks_vs.hpp

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* Update block_fmha_pipeline_qx_ks_vs_custom_policy.hpp

* clang format

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* use same naming style

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Linjun-AMD
2025-08-28 11:39:39 +08:00
committed by GitHub
parent f5f795c4d6
commit bf7b458e6e

View File

@@ -58,6 +58,20 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Block-level GEMM used for the first (Q x K^T) stage of the attention pipeline,
// obtained from the pipeline policy.
using BlockGemm0 = remove_cvref_t<decltype(Policy::template GetQKBlockGemm<Problem>())>;
// Tuple describing gemm0's warp decomposition: element 0 is the warp-gemm type,
// elements 1 and 2 are the warp counts along M and N respectively.
static constexpr auto WarpGemmConfig =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
// Per-warp-gemm tile extents along M/N/K for gemm0.
static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
// Number of MFMA (warp-gemm) issues per warp for one kM0 x kN0 x kK0 block tile:
// total warp-gemm tiles in the block divided by the number of participating warps.
static constexpr int NumMfmaInsts =
(kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
// Instruction-class masks for __builtin_amdgcn_sched_group_barrier
// (per the LLVM AMDGPU backend: 0x100 selects DS read instructions,
// 0x008 selects MFMA/WMMA instructions).
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
@@ -281,6 +295,22 @@ struct BlockFmhaPipelineQRKSVS
index_t i_total_loops = 0;
// Number of K-dimension sub-iterations for gemm0 (Q x K^T) and gemm1 (P x V).
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
// Use compile-time conditional for group barrier sequence
// (No runtime lambda selection)
// schedule_gemm0: emits scheduler hints that interleave LDS reads with MFMAs
// for head dim 256; compiles to a no-op for every other head dim.
// Each __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) call asks the
// backend scheduler to place `size` instructions of class `mask` next (sync_id 0
// = default group). One static_for iteration therefore requests the pattern
// 2xDS-read, 2xMFMA, 1xDS-read, 2xMFMA, 1xDS-read, 4xMFMA — i.e. 4 DS reads
// interleaved with 8 MFMAs, which is why NumMfmaInsts must divide evenly by 8.
auto schedule_gemm0 = [] {
if constexpr(kQKHeaddim == 256)
{
static_assert(NumMfmaInsts % 8 == 0);
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
});
}
};
// The pipeline needs at least two gemm0 K-steps (prologue + steady state) and
// at least one gemm1 K-step.
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
@@ -323,6 +353,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
@@ -341,6 +372,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
@@ -351,6 +383,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
schedule_gemm0();
}
// STAGE 2, scale_s, add bias, mask, softmax