diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 8690ba290c..d68270394e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -623,38 +623,40 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1) // AMD buffer atomic operations require contiguous output layout - if(a.k_batch > 1) + if(kernel_arg.k_batch > 1) { if constexpr(std::is_same_v) { - if(a.StrideC != a.N) + if(kernel_arg.StrideC != kernel_arg.N) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[" << __func__ << "] group id: " << i - << " SplitK (k_batch=" << a.k_batch + << " SplitK (k_batch=" << kernel_arg.k_batch << ") requires contiguous output stride." << " For RowMajor layout: StrideC must equal N." - << " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl; + << " Got StrideC=" << kernel_arg.StrideC + << ", N=" << kernel_arg.N << std::endl; } return false; } } else if constexpr(std::is_same_v) { - if(a.StrideC != a.M) + if(kernel_arg.StrideC != kernel_arg.M) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[" << __func__ << "] group id: " << i - << " SplitK (k_batch=" << a.k_batch + << " SplitK (k_batch=" << kernel_arg.k_batch << ") requires contiguous output stride." << " For ColumnMajor layout: StrideC must equal M." - << " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl; + << " Got StrideC=" << kernel_arg.StrideC + << ", M=" << kernel_arg.M << std::endl; } return false; } @@ -666,7 +668,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 0) { - group_arg_valid = GridwiseGemm64::CheckValidity(a); + group_arg_valid = GridwiseGemm64::CheckValidity(kernel_arg); } } else @@ -674,7 +676,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 0) { group_arg_valid = GridwiseGemm32::CheckValidity( - reinterpret_cast(a)); + reinterpret_cast(kernel_arg)); } } @@ -684,7 +686,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK