diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp index 7a565fbaa7..c5307c4d7a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp @@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto i) { ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(MPerBlock >= 128 && NPerBlock >= 128) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); @@ -203,11 +211,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - }); + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); } template {}); static constexpr auto xdlops_gemm = XdlopsGemm{};