From e7f8544bcdc7403adfcb484f5d9d2b9e2bb82821 Mon Sep 17 00:00:00 2001 From: letaoqin Date: Thu, 6 Mar 2025 08:24:37 +0000 Subject: [PATCH] add limiting for N0/NWave --- ...ise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 31b07d9e4b..20cf6d79da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -923,20 +923,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - // should remove kpading - if((GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) && - ((karg.BK0Shuffled % karg.KBatch) != 0)) - { - return false; - } - // for not adding k padd operator - if(CalculateBKShufflePadded(karg.K) % KPerBlock != 0) + if((CalculateBKShufflePadded(karg.K) % KPerBlock != 0) || + (karg.BK0Shuffled % karg.KBatch != 0)) { return false; } - if(karg.N % NPerXdl != 0) + if((karg.N % NPerXdl != 0) || (karg.BN0Shuffled % NWave != 0)) { return false; }