mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#5785 (commit d8ecfc1)
[CK] Fix min k_batch calculation in conv kernels ## Motivation Avoid division by 0 and remove not needed "-1". ## Technical Details Our div up implementation return lower value if input is divisible. There is no need to subtract 1. ## Test Plan test_grouped_conv_bwd_weight ## Test Result Passed locally. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-1019
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4c926497ad
commit
c28d0033d7
@@ -652,7 +652,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
|
||||
@@ -596,7 +596,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
|
||||
@@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
|
||||
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
|
||||
@@ -537,7 +537,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
|
||||
@@ -607,7 +607,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
const auto k_batch_max = math::integer_divide_ceil(gemmK, K0PerBlock);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
|
||||
Reference in New Issue
Block a user