mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Rename single-character variable
This commit is contained in:
@@ -623,38 +623,40 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool isWave64 = get_warp_size() == 64;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& kernel_arg = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
// Validate stride requirements for SplitK (k_batch > 1)
|
||||
// AMD buffer atomic operations require contiguous output layout
|
||||
if(a.k_batch > 1)
|
||||
if(kernel_arg.k_batch > 1)
|
||||
{
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(a.StrideC != a.N)
|
||||
if(kernel_arg.StrideC != kernel_arg.N)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< " SplitK (k_batch=" << kernel_arg.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For RowMajor layout: StrideC must equal N."
|
||||
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
|
||||
<< " Got StrideC=" << kernel_arg.StrideC
|
||||
<< ", N=" << kernel_arg.N << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if(a.StrideC != a.M)
|
||||
if(kernel_arg.StrideC != kernel_arg.M)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< " SplitK (k_batch=" << kernel_arg.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For ColumnMajor layout: StrideC must equal M."
|
||||
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
|
||||
<< " Got StrideC=" << kernel_arg.StrideC
|
||||
<< ", M=" << kernel_arg.M << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -666,7 +668,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
group_arg_valid = GridwiseGemm64::CheckValidity(a);
|
||||
group_arg_valid = GridwiseGemm64::CheckValidity(kernel_arg);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -674,7 +676,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
group_arg_valid = GridwiseGemm32::CheckValidity(
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(a));
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(kernel_arg));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -684,7 +686,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
kernel_arg.Print();
|
||||
}
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
|
||||
Reference in New Issue
Block a user