From 4a4df2657ad92f13394c33fb676cd502c62b9897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 27 Aug 2025 12:35:42 +0200 Subject: [PATCH] Fix splitk autodeduce for grouped conv bwd weight (#2742) [ROCm/composable_kernel commit: cfe5e448dbf2d60ee22358e3d047600aca004090] --- .../device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 2 +- .../impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 95361287db..4565074b3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -641,7 +641,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock); - k_batch_ = std::min(k_batch_, k_batch_max); + k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 6e74899706..0793285dbd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. 
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock); - k_batch_ = std::min(k_batch_, k_batch_max); + k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) {