From 39bb48ed2e6a0041b77e29b11b536bf77e4b2f80 Mon Sep 17 00:00:00 2001
From: Johannes Graner
Date: Mon, 15 Dec 2025 08:03:00 +0100
Subject: [PATCH] [CK Grouped Gemm] Disable split-k kernel for split-k > 1 with non-contiguous strides (#3405)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Disable kernel for split-k > 1 with non-contiguous strides

* Update device_grouped_gemm_xdl_splitk_cshuffle.hpp

---------

AICK-441 (partial)

Co-authored-by: Bartłomiej Kocot
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 3143a5a480e4fcf216670012fe491b44324f03b6]
---
 ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 39 ++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
index ec48beb789..1db9fd45b8 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1)
+    // TODO: Enable splitK
+    if(a.k_batch > 1)
+    {
+        if constexpr(std::is_same_v)
+        {
+            if(a.StrideC != a.N)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " SplitK (k_batch=" << a.k_batch
+                              << ") requires contiguous output stride."
+                              << " For RowMajor layout: StrideC must equal N."
+                              << " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
+                }
+                return false;
+            }
+        }
+        else if constexpr(std::is_same_v)
+        {
+            if(a.StrideC != a.M)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " SplitK (k_batch=" << a.k_batch
+                              << ") requires contiguous output stride."
+                              << " For ColumnMajor layout: StrideC must equal M."
+                              << " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
+                }
+                return false;
+            }
+        }
+    }
+
     bool group_arg_valid = false;
     if(isWave64)
     {