From f1d644d4cdcc40378a2144fd19ae08da017596e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 30 Jul 2025 11:33:12 +0000 Subject: [PATCH] Calculate grid size for split-K autodeduction directly from input array shapes and template params. --- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 27 +++----------- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 36 ++++--------------- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 27 +++----------- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 33 +++-------------- .../gpu/device/impl/split_k_utils.hpp | 8 +++++ 5 files changed, 27 insertions(+), 104 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 0469f10f96..6558c10c95 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -645,31 +645,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle if(split_k < 0) { - constexpr int k_batch_initial = 1; - const auto descs_initial = - conv_to_gemm_transformer - .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - b_g_n_c_wis_strides, - e_g_k_c_xs_strides, - a_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_initial); - - const auto& ce_grid_desc_m_n = descs_initial[I2]; - const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor( - ce_grid_desc_m_n, M01, N01, k_batch_initial); + ck::index_t gemmM, gemmN; + std::tie(gemmM, gemmN, std::ignore) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); const auto grid_size = - block_2_ctile_map.CalculateGridSize(ce_grid_desc_m_n) * Conv_G_; + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index be81692181..3bf41ceab6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -629,41 +629,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(split_k < 0) { - constexpr int k_batch_initial = 1; - const auto descs_initial = - conv_to_gemm_transformer_v2 - .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - b_g_n_c_wis_strides, - e_g_k_c_xs_strides, - a_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_initial); + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); - const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0]; - const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1]; - const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1); - const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1); - - const auto grid_size = - GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_ / - NumGroupsToMerge; + const auto grid_size = calculate_mn_grid_size(gemmM, gemmN) * + Conv_G_ / NumGroupsToMerge; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); // Ensure that k_batch_ does not exceed the maximum value - // for the GEMM pipeline - ck::index_t gemmK; - std::tie(std::ignore, std::ignore, gemmK) = - get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + // for the GEMM pipeline. const auto k_batch_max = static_cast((gemmK - 1) / KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index c2dc5418e7..49fea59aff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -528,31 +528,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle if(split_k < 0) { - constexpr int k_batch_initial = 1; - const auto descs_initial = - conv_to_gemm_transformer - .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - b_g_n_c_wis_strides_transposed, - e_g_k_c_xs_strides_transposed, - a_g_n_k_wos_strides_transposed, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_initial); - - const auto& c_grid_desc_m_n = descs_initial[I2]; - const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor( - c_grid_desc_m_n, M01, N01, k_batch_initial); + ck::index_t gemmM, gemmN; + std::tie(gemmM, gemmN, std::ignore) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); const auto grid_size = - block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * Conv_G_; + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 167a8144f2..1bd203eba6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -494,40 +494,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if(split_k < 0) { - constexpr int k_batch_initial = 1; - const auto descs_initial = - conv_to_gemm_transformer - .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - b_g_n_c_wis_strides, - e_g_k_c_xs_strides, - a_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_initial); - - const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0]; - const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1]; - const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1); - const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1); + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); const auto grid_size = - GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_; + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); // Ensure that k_batch_ does not exceed the maximum value - // for the GEMM pipeline - ck::index_t gemmK; - std::tie(std::ignore, std::ignore, gemmK) = - get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + // for the GEMM pipeline. const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); k_batch_ = std::min(k_batch_, k_batch_max); diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 99e0635bec..32179d179e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -80,6 +80,14 @@ get_bwd_weight_gemm_sizes(const std::array& a_g_n_k_wo return std::make_tuple(gemmM, gemmN, gemmK); } +template +inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN) +{ + const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock); + const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock); + return M0 * N0; +} + } // namespace device } // namespace tensor_operation } // namespace ck