From 62b08dd1bbf79d21ee91f08152b79f7a21bd9aa6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <>
Date: Wed, 9 Jul 2025 09:22:49 -0500
Subject: [PATCH] Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

When split_k is negative, Argument derives k_batch_ from the kernel's
maximum active blocks per multiprocessor, the M/N grid size and Conv_G_.
The result is capped at (gemmK - 1) / KPerBlock and floored at 1, so it
can no longer become 0 when gemmK <= KPerBlock. Non-negative split_k
values are used as given.
---
Review notes (this section is not part of the commit message):
 * The patch was received with all line breaks collapsed and with every
   token enclosed in angle brackets removed. Line breaks and indentation
   below are reconstructed; verify them against the target tree (apply
   with whitespace tolerance if needed).
 * Template arguments in added lines (remove_reference_t<...>,
   ComputePtrOffsetOfStridedBatch<...>, <NDimSpatial>, static_cast<index_t>)
   are reconstructed - TODO confirm against the target header.
 * Context lines whose angle-bracket content was lost have been trimmed
   from the hunks instead of being guessed; hunk ranges were adjusted.
 * The post-image blob id on the index line is carried over unchanged and
   no longer corresponds to the modified content.

 ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 121 +++++++++++++++++-
 1 file changed, 117 insertions(+), 4 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
index c7c463f43d..9602fbaee0 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
@@ -24,3 +24,5 @@
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
+#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
+#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/host_utility/device_prop.hpp"
@@ -504,7 +506,61 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
             CGridDesc_M_N{}, 1, 1));
 
-    struct Argument : public BaseArgument
+    // Queries hipOccupancyMaxActiveBlocksPerMultiprocessor once, at construction,
+    // for the backward-weight kernel instantiation selected by BlkGemmPipelineVer
+    // (the _2lds kernel for v4, the plain v3 kernel otherwise) and stores the
+    // result, clamped to at least 1, in value_.
+    struct MaximumActiveBlocksPerMultiprocessor
+    {
+        MaximumActiveBlocksPerMultiprocessor()
+        {
+            constexpr int dynamic_smem_size = 0;
+            constexpr index_t minimum_occupancy =
+                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+            int max_occupancy = 0;
+
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+            {
+                hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+                    &max_occupancy,
+                    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
+                        GridwiseGemm,
+                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<
+                            DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
+                        NumGroupsToMerge,
+                        true,
+                        InMemoryDataOperationEnum::AtomicAdd,
+                        minimum_occupancy>,
+                    BlockSize,
+                    dynamic_smem_size));
+            }
+            else
+            {
+                hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+                    &max_occupancy,
+                    kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                        GridwiseGemm,
+                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<
+                            DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
+                        NumGroupsToMerge,
+                        true,
+                        InMemoryDataOperationEnum::AtomicAdd,
+                        minimum_occupancy>,
+                    BlockSize,
+                    dynamic_smem_size));
+            }
+            value_ = std::max(1, max_occupancy);
+        }
+        int value_;
+    };
+
+    struct Argument : public BaseArgument, public ArgumentSplitK
     {
         Argument(const InDataType* p_in_grid,
                  WeiDataType* p_wei_grid,
@@ -547,7 +603,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
               output_spatial_lengths_{},
               conv_filter_strides_{conv_filter_strides},
               input_left_pads_{input_left_pads},
-              input_right_pads_{input_right_pads},
-              k_batch_{split_k}
+              input_right_pads_{input_right_pads}
         {
+            // Function-local static: the occupancy query runs once, on the first
+            // Argument construction for this instantiation, and is reused afterwards.
+            // NOTE(review): the cached value reflects the device that was current
+            // at that first call - confirm this is acceptable for multi-GPU use.
+            static MaximumActiveBlocksPerMultiprocessor max_occupancy;
+
             c_space_size_bytes =
@@ -576,5 +637,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
                                                                 e_g_k_c_xs_strides);
 
+            // A negative split_k requests automatic deduction of k_batch_.
+            if (split_k < 0)
+            {
+                // Descriptors built with k_batch = 1 are used only to read gemmM and gemmN.
+                constexpr int k_batch_initial = 1;
+                const auto descs_initial =
+                    conv_to_gemm_transformer_v2
+                        .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
+                            Conv_N_,
+                            Conv_K_,
+                            Conv_C_,
+                            input_spatial_lengths_,
+                            filter_spatial_lengths_,
+                            output_spatial_lengths_,
+                            b_g_n_c_wis_strides,
+                            e_g_k_c_xs_strides,
+                            a_g_n_k_wos_strides,
+                            conv_filter_strides,
+                            conv_filter_dilations,
+                            input_left_pads,
+                            input_right_pads,
+                            k_batch_initial);
+
+                const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
+                const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
+                const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
+                const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
+
+                const auto grid_size_mn = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN);
+                k_batch_ = get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size_mn, Conv_G_);
+
+                // Ensure that k_batch_ does not exceed the maximum value
+                // for the GEMM pipeline, and that it never falls below 1
+                // (k_batch_max evaluates to 0 when gemmK <= KPerBlock).
+                ck::index_t gemmK;
+                std::tie(std::ignore, std::ignore, gemmK) =
+                    get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
+                const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
+                k_batch_ = std::max<index_t>(std::min(k_batch_, k_batch_max), 1);
+
+                if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
+                        << k_batch_max << std::endl;
+                    std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
+                        << k_batch_ << std::endl;
+                }
+            }
+            else
+            {
+                k_batch_ = split_k;
+            }
+
             const auto descs =
                 conv_to_gemm_transformer_v2
@@ -754,3 +868,2 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
-        const index_t k_batch_;
         long_index_t c_space_size_bytes;
     };