diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 3a209eccf0..0f21f77992 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1757,20 +1757,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){ - __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write - } - else if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){ - __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write + if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){ + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write + } + else{ + __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write + } } }); static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read - } - else if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + } + else{ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + } } }); } @@ -1784,13 +1788,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_INST = Gemm4MFMA; // To hide instruction issue latency - constexpr index_t LDS_READ_PER_MFMA = - LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1; + constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); static_for<0, MFMA_INST, 1>{}([&](auto i) { - ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ + if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + } + else{ + __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + } + } }); } @@ -1843,11 +1852,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t QT_LDS_READ = kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = - kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + // kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2; static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t SGradT_LDS_READ_P2 = - kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + // kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); + kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2; static constexpr index_t OGrad_LDS_READ = kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);