diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bfb567d1e0..a3eab579e7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -652,7 +652,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); // Cap k_batch_ to 128 to avoid accuracy issues diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 053f0eb3ae..87117be4ce 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -596,7 +596,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 99ec3387dc..0ee5ac3647 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = static_cast((gemmK - 1) / KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 2bce582f68..bfc88753a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -537,7 +537,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); // Cap k_batch_ to 128 to avoid accuracy issues diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2ab60581e7..dade0515af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -607,7 +607,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); // Cap k_batch_ to 128 to avoid accuracy issues