diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index ec48beb789..1db9fd45b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1) + // TODO: Enable splitK + if(a.k_batch > 1) + { + if constexpr(std::is_same_v) + { + if(a.StrideC != a.N) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " SplitK (k_batch=" << a.k_batch + << ") requires contiguous output stride." + << " For RowMajor layout: StrideC must equal N." + << " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl; + } + return false; + } + } + else if constexpr(std::is_same_v) + { + if(a.StrideC != a.M) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " SplitK (k_batch=" << a.k_batch + << ") requires contiguous output stride." + << " For ColumnMajor layout: StrideC must equal M." + << " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl; + } + return false; + } + } + } + bool group_arg_valid = false; if(isWave64) {