From 48e0d510f0039492ffb4171c2eba997b0709ee89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 27 Mar 2026 16:37:37 +0100 Subject: [PATCH] [CK] Fix min k_batch calculation in conv kernels (#5785) ## Motivation Avoid division by 0 and remove the unneeded "-1". ## Technical Details Our div-up implementation already returns the lower (exact) value when the input is evenly divisible, so there is no need to subtract 1. ## Test Plan test_grouped_conv_bwd_weight ## Test Result Passed locally. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-1019 --- ...vice_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...evice_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 2 +- .../device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 2 +- .../impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 2 +- .../impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bfb567d1e0..a3eab579e7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -652,7 +652,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. 
- const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); // Cap k_batch_ to 128 to avoid accuracy issues diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 053f0eb3ae..87117be4ce 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -596,7 +596,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 99ec3387dc..0ee5ac3647 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. 
- const auto k_batch_max = static_cast((gemmK - 1) / KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 2bce582f68..bfc88753a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -537,7 +537,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); // Cap k_batch_ to 128 to avoid accuracy issues diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2ab60581e7..dade0515af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -607,7 +607,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. - const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); + const auto k_batch_max = math::integer_divide_ceil(gemmK, K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); // Cap k_batch_ to 128 to avoid accuracy issues