Add validity check for N0/NWave (BN0Shuffled must be divisible by NWave)

This commit is contained in:
letaoqin
2025-03-06 08:24:37 +00:00
parent d348c3fa4c
commit e7f8544bcd

View File

@@ -923,20 +923,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
// TODO: should remove KPadding
if((GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) &&
((karg.BK0Shuffled % karg.KBatch) != 0))
{
return false;
}
// No K-padding operator is added, so the padded shuffled K must be a multiple of KPerBlock
if(CalculateBKShufflePadded(karg.K) % KPerBlock != 0)
if((CalculateBKShufflePadded(karg.K) % KPerBlock != 0) ||
(karg.BK0Shuffled % karg.KBatch != 0))
{
return false;
}
if(karg.N % NPerXdl != 0)
if((karg.N % NPerXdl != 0) || (karg.BN0Shuffled % NWave != 0))
{
return false;
}