diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp index 9f91c06e8e..d01c1d4936 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp @@ -248,6 +248,14 @@ struct BlockGemmARegBRegCRegEightWavesV1 merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); + if constexpr(nIter == 0 && mIter == MIterPerWarp - 1 && kIter == 0) + { + s_nop(); + s_waitcnt_lgkm<4>(); + __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); + __builtin_amdgcn_sched_barrier(0); + } }); }); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp index 04112c6357..2b4e2ed849 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -183,12 +183,6 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp // Hot loop scheduler // ------------------ auto hot_loop_scheduler = [&]() { - __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA - s_waitcnt_lgkm<4>(); - __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU - static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); __builtin_amdgcn_sched_barrier(0); }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp index 823c4eef32..d8a54dfc86 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp @@ -383,7 +383,6 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< auto calc_gemm = [&](index_t i) { __builtin_amdgcn_sched_barrier(0); - s_nop(); block_gemm( c_block_tile, a_block_tile, b_block_tiles, aq_block_tile[i], bq_block_tile[i]); scheduler_func(); @@ -392,6 +391,7 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< auto main_body = [&](auto tic, auto toc) { __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); calc_gemm(tic); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eight_waves.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eight_waves.hpp index 61e3a00fd9..641af284ba 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eight_waves.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eight_waves.hpp @@ -255,6 +255,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase "C block tensor data type!"); constexpr auto warp_size = get_warp_size(); + s_nop(); + auto q_block_tensor = aq_block_tensor; if constexpr(Traits::NQPerBlock / NWarp == 1) {