[CK Grouped Gemm] Disable split-k kernel for split-k > 1 with non-contiguous strides (#3405)

* Disable kernel for split-k > 1 with non-contiguous strides

* Update device_grouped_gemm_xdl_splitk_cshuffle.hpp

---------

AICK-441 (partial)

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Johannes Graner
2025-12-15 08:03:00 +01:00
committed by GitHub
parent f5573f56d9
commit 3143a5a480

View File

@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
// Validate stride requirements for SplitK (k_batch > 1)
// TODO: Enable splitK
if(a.k_batch > 1)
{
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
if(a.StrideC != a.N)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For RowMajor layout: StrideC must equal N."
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
}
return false;
}
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
if(a.StrideC != a.M)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For ColumnMajor layout: StrideC must equal M."
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
}
return false;
}
}
}
bool group_arg_valid = false;
if(isWave64)
{