Re-enable two stage kernel

This commit is contained in:
Graner, Johannes
2025-12-18 05:07:59 -05:00
parent 323e014799
commit adbfcad03b

View File

@@ -848,23 +848,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return false;
}
// TODO: Fix this.
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
"supported for all possible shapes!"
<< std::endl;
}
return false;
}
bool supported = true;
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)