Remove autodeduce 1 stage

This commit is contained in:
Enrico Degregori
2025-12-12 10:35:30 +00:00
parent 0f1bb0e817
commit a87256a676
2 changed files with 45 additions and 2 deletions

View File

@@ -583,7 +583,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
k_batch_ = split_k;
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
#endif
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
@@ -954,6 +983,13 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -532,6 +532,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1034,6 +1035,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *