Calculate grid size for split-K autodeduction directly from input array shapes and template params.

This commit is contained in:
Ville Pietilä
2025-07-30 11:33:12 +00:00
parent 94f7b441f2
commit f1d644d4cd
5 changed files with 27 additions and 104 deletions

View File

@@ -645,31 +645,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
if(split_k < 0)
{
constexpr int k_batch_initial = 1;
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_initial);
const auto& ce_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(
ce_grid_desc_m_n, M01, N01, k_batch_initial);
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
block_2_ctile_map.CalculateGridSize(ce_grid_desc_m_n) * Conv_G_;
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}

View File

@@ -629,41 +629,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(split_k < 0)
{
constexpr int k_batch_initial = 1;
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_initial);
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
const auto grid_size =
GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_ /
NumGroupsToMerge;
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
Conv_G_ / NumGroupsToMerge;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline
ck::index_t gemmK;
std::tie(std::ignore, std::ignore, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);

View File

@@ -528,31 +528,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
if(split_k < 0)
{
constexpr int k_batch_initial = 1;
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_initial);
const auto& c_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(
c_grid_desc_m_n, M01, N01, k_batch_initial);
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * Conv_G_;
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}

View File

@@ -494,40 +494,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if(split_k < 0)
{
constexpr int k_batch_initial = 1;
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_initial);
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_;
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline
ck::index_t gemmK;
std::tie(std::ignore, std::ignore, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);

View File

@@ -80,6 +80,14 @@ get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wo
return std::make_tuple(gemmM, gemmN, gemmK);
}
/// @brief Number of MPerBlock x NPerBlock output tiles spanning a gemmM x gemmN problem.
///
/// Each GEMM dimension is divided by its per-block tile size using
/// math::integer_divide_ceil, and the two per-dimension tile counts are
/// multiplied together. Only the M and N extents are considered here; any
/// further scaling of the grid (for example by a group count) is left to the
/// caller.
///
/// @tparam MPerBlock Tile size along the GEMM M dimension.
/// @tparam NPerBlock Tile size along the GEMM N dimension.
/// @param gemmM Extent of the GEMM M dimension.
/// @param gemmN Extent of the GEMM N dimension.
/// @return Product of the tile counts along M and N.
template <ck::index_t MPerBlock, ck::index_t NPerBlock>
inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
{
    return math::integer_divide_ceil(gemmM, MPerBlock) *
           math::integer_divide_ceil(gemmN, NPerBlock);
}
} // namespace device
} // namespace tensor_operation
} // namespace ck