Fix splitk autodeduce for grouped conv bwd weight (#2742)

This commit is contained in:
Bartłomiej Kocot
2025-08-27 12:35:42 +02:00
committed by GitHub
parent 245467f359
commit cfe5e448db
2 changed files with 2 additions and 2 deletions

View File

@@ -641,7 +641,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{

View File

@@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{