mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Fix splitk autodeduce for grouped conv bwd weight (#2742)
This commit is contained in:
@@ -641,7 +641,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
|
||||
@@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user