[rocm-libraries] ROCm/rocm-libraries#5785 (commit d8ecfc1)

[CK] Fix min k_batch calculation in conv kernels

## Motivation

Avoid division by 0 and remove not needed "-1".

## Technical Details

Our div up implementation return lower value if input is divisible.
There is no need to subtract 1.

## Test Plan

test_grouped_conv_bwd_weight

## Test Result

Passed locally.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-1019
This commit is contained in:
Bartłomiej Kocot
2026-03-27 15:38:21 +00:00
committed by assistant-librarian[bot]
parent 4c926497ad
commit c28d0033d7
5 changed files with 5 additions and 5 deletions

View File

@@ -652,7 +652,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues

View File

@@ -596,7 +596,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))

View File

@@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))

View File

@@ -537,7 +537,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues

View File

@@ -607,7 +607,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
const auto k_batch_max = math::integer_divide_ceil(gemmK, K0PerBlock);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
// Cap k_batch_ to 128 to avoid accuracy issues