mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Calculate grid size for split-K autodeduction directly from input array shapes and template params.
This commit is contained in:
@@ -645,31 +645,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
constexpr int k_batch_initial = 1;
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
|
||||
const auto& ce_grid_desc_m_n = descs_initial[I2];
|
||||
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(
|
||||
ce_grid_desc_m_n, M01, N01, k_batch_initial);
|
||||
ck::index_t gemmM, gemmN;
|
||||
std::tie(gemmM, gemmN, std::ignore) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
block_2_ctile_map.CalculateGridSize(ce_grid_desc_m_n) * Conv_G_;
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
}
|
||||
|
||||
@@ -629,41 +629,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
constexpr int k_batch_initial = 1;
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
std::tie(gemmM, gemmN, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
|
||||
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
|
||||
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
|
||||
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
|
||||
|
||||
const auto grid_size =
|
||||
GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_ /
|
||||
NumGroupsToMerge;
|
||||
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
|
||||
Conv_G_ / NumGroupsToMerge;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline
|
||||
ck::index_t gemmK;
|
||||
std::tie(std::ignore, std::ignore, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
|
||||
@@ -528,31 +528,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
constexpr int k_batch_initial = 1;
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
|
||||
const auto& c_grid_desc_m_n = descs_initial[I2];
|
||||
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(
|
||||
c_grid_desc_m_n, M01, N01, k_batch_initial);
|
||||
ck::index_t gemmM, gemmN;
|
||||
std::tie(gemmM, gemmN, std::ignore) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * Conv_G_;
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
}
|
||||
|
||||
@@ -494,40 +494,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
constexpr int k_batch_initial = 1;
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
|
||||
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
|
||||
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
|
||||
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
|
||||
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
std::tie(gemmM, gemmN, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_;
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline
|
||||
ck::index_t gemmK;
|
||||
std::tie(std::ignore, std::ignore, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
|
||||
@@ -80,6 +80,14 @@ get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wo
|
||||
return std::make_tuple(gemmM, gemmN, gemmK);
|
||||
}
|
||||
|
||||
template <ck::index_t MPerBlock, ck::index_t NPerBlock>
|
||||
inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock);
|
||||
return M0 * N0;
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user