Merge commit 'bf7b458e6ebaafaac3867b0af468f87d978757ae' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-28 04:13:22 +00:00
parent d279e73d10
commit e121136b89

View File

@@ -58,6 +58,20 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// Block-level GEMM type returned by Policy::GetQKBlockGemm<Problem>()
// (presumably the Q*K GEMM, i.e. "gemm0" -- confirm against the policy).
using BlockGemm0 = remove_cvref_t<decltype(Policy::template GetQKBlockGemm<Problem>())>;
// Tuple-like configuration of gemm0 queried from the block GEMM's policy.
// It is indexed below as: at<0>() -> warp-level GEMM object,
// at<1>() -> value stored as Gemm0MWarp, at<2>() -> value stored as Gemm0NWarp
// (presumably the number of warps along M and N -- confirm against the policy).
static constexpr auto WarpGemmConfig =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
// M/N/K extents of the warp GEMM implementation used by gemm0.
static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
// Number of WarpGemm0-sized tiles in one kM0 x kN0 x kK0 block tile, divided
// by Gemm0MWarp * Gemm0NWarp. Used as the MFMA count when building the
// scheduling hints for gemm0.
// NOTE(review): this equals the per-warp MFMA instruction count only if each
// warp-GEMM tile maps to exactly one MFMA instruction -- confirm.
// NOTE(review): every division here is an integer division, so any extent that
// is not an exact multiple of its divisor is silently truncated.
static constexpr int NumMfmaInsts =
(kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
// Instruction-class masks passed as the first argument of
// __builtin_amdgcn_sched_group_barrier when scheduling gemm0.
static constexpr uint32_t DS_READ = 0x100; // mask selecting DS (data share) read instructions
static constexpr uint32_t MFMA = 0x008; // mask selecting MFMA (matrix multiply-accumulate) instructions
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
@@ -281,6 +295,22 @@ struct BlockFmhaPipelineQRKSVS
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
// Emits instruction-scheduling hints for one gemm0 call.
// The choice is made at compile time with `if constexpr`, so no runtime
// branch or runtime selection between lambdas is involved:
//  - kQKHeaddim == 256: emit NumMfmaInsts / 8 repetitions of a fixed
//    sched_group_barrier sequence (see below).
//  - any other head dimension: the lambda body is empty and emits nothing.
// The lambda captures nothing; it only reads static constexpr members.
auto schedule_gemm0 = [] {
if constexpr(kQKHeaddim == 256)
{
// One repetition of the pattern below accounts for 2 + 2 + 4 = 8 MFMAs,
// so the total MFMA count must be a multiple of 8.
static_assert(NumMfmaInsts % 8 == 0);
// The loop index is unused; static_for only repeats the pattern.
// Each repetition requests, in order, 4 DS reads interleaved with
// 8 MFMAs. The arguments of each call are (mask, count, 0); the last
// argument is presumably the sync group ID -- confirm against the LLVM
// AMDGPU documentation.
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // 2 x DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // 2 x MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // 1 x DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // 2 x MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // 1 x DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // 4 x MFMA
});
}
};
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
@@ -323,6 +353,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
@@ -341,6 +372,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
@@ -351,6 +383,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
schedule_gemm0();
}
// STAGE 2, scale_s, add bias, mask, softmax