diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
index 95361287db..4565074b3e 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
@@ -641,7 +641,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 // Ensure that k_batch_ does not exceed the maximum value
                 // for the GEMM pipeline.
                 const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
-                k_batch_ = std::min(k_batch_, k_batch_max);
+                k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
 
                 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                 {
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
index 6e74899706..0793285dbd 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
@@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
                 // Ensure that k_batch_ does not exceed the maximum value
                 // for the GEMM pipeline.
                 const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
-                k_batch_ = std::min(k_batch_, k_batch_max);
+                k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
 
                 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                 {