Merge commit '3143a5a480e4fcf216670012fe491b44324f03b6' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-15 07:16:25 +00:00
parent 669906c786
commit 6164d076de

View File

@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
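// Validate stride requirements for SplitK (k_batch > 1): the output must be
// contiguous, i.e. StrideC == N for a RowMajor ELayout and StrideC == M for a
// ColumnMajor ELayout. Any group that violates this makes the check return false
// (the reason is printed when CK_LOGGING is enabled).
// TODO: Enable splitK
// NOTE(review): the TODO presumably refers to lifting this stride restriction
// for split-K — confirm with the author.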
if(a.k_batch > 1)
{
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
if(a.StrideC != a.N)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For RowMajor layout: StrideC must equal N."
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
}
return false;
}
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
if(a.StrideC != a.M)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For ColumnMajor layout: StrideC must equal M."
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
}
return false;
}
}
}
bool group_arg_valid = false;
if(isWave64)
{