diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 29f183c613..f8d9973918 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -58,6 +58,20 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + using BlockGemm0 = remove_cvref_t())>; + static constexpr auto WarpGemmConfig = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm0 = remove_cvref_t())>; + static constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>(); + static constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>(); + static constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM; + static constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN; + static constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK; + static constexpr int NumMfmaInsts = + (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp); + static constexpr uint32_t DS_READ = 0x100; // First-argument mask for __builtin_amdgcn_sched_group_barrier: DS (data share) read group + static constexpr uint32_t MFMA = 0x008; // First-argument mask for __builtin_amdgcn_sched_group_barrier: MFMA (matrix multiply-accumulate) group + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -281,6 +295,22 @@ struct BlockFmhaPipelineQRKSVS index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; + // schedule_gemm0 emits the sched_group_barrier sequence below only when kQKHeaddim == 256; + // the choice is made with if constexpr, so for any other value the lambda body is empty + auto schedule_gemm0 = [] { + if constexpr(kQKHeaddim == 256) + { + static_assert(NumMfmaInsts % 8 == 0); // the MFMA group sizes passed in one unrolled step below are 2, 2 and 4 (8 in total), and the step count is NumMfmaInsts / 8 + static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read group, size 2 + 
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA group, size 2 + __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read group, size 1 + __builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA group, size 2 + __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read group, size 1 + __builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA group, size 4 + }); + } + }; static_assert(2 <= k0_loops); static_assert(1 <= k1_loops); @@ -323,6 +353,7 @@ struct BlockFmhaPipelineQRKSVS sequence<0, i_k0 * kK0>{}, sequence{}), k_lds_window); + schedule_gemm0(); block_sync_lds(); move_tile_window(k_dram_window, {0, kK0}); @@ -341,6 +372,7 @@ struct BlockFmhaPipelineQRKSVS sequence<0, (k0_loops - 2) * kK0>{}, sequence{}), k_lds_window); + schedule_gemm0(); block_sync_lds(); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); @@ -351,6 +383,7 @@ struct BlockFmhaPipelineQRKSVS sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), k_lds_window); + schedule_gemm0(); } // STAGE 2, scale_s, add bias, mask, softmax