From 4e92d44cd72fdab18d15e3f92d6548f46fac9de5 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 14 Feb 2025 06:11:01 +0000 Subject: [PATCH] tempsave --- ...block_fmha_bwd_pipeline_default_policy.hpp | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 2163a93d7c..3a209eccf0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1749,27 +1749,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_INST = Gemm3MFMA; // To hide instruction issue latency - constexpr index_t LDS_WRITE_PER_MFMA = - LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1; + constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST); constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA; - constexpr index_t LDS_READ_PER_MFMA = - (MFMA_INST - MFMA_INST_LDS_WRITE) > 0 - ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0 - ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) - : 1 - : 0; + constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE)); static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { - ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write + if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){ + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write + } + else if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){ + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write + } }); static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { - ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + } + else if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + } }); } @@ -1782,7 +1784,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_INST = Gemm4MFMA; // To hide instruction issue latency - constexpr index_t LDS_READ_PER_MFMA = + constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1; static_for<0, MFMA_INST, 1>{}([&](auto i) {