diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 31b07d9e4b..20cf6d79da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -923,20 +923,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - // should remove kpading - if((GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) && - ((karg.BK0Shuffled % karg.KBatch) != 0)) - { - return false; - } - // for not adding k padd operator - if(CalculateBKShufflePadded(karg.K) % KPerBlock != 0) + if((CalculateBKShufflePadded(karg.K) % KPerBlock != 0) || + (karg.BK0Shuffled % karg.KBatch != 0)) { return false; } - if(karg.N % NPerXdl != 0) + if((karg.N % NPerXdl != 0) || (karg.BN0Shuffled % NWave != 0)) { return false; }