diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 672c7dd2f7..f685d80a04 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -19,6 +19,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -543,7 +545,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using Block2CTileMap = decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); - struct Argument : public BaseArgument + /* Helper whose constructor asks hipOccupancyMaxActiveBlocksPerMultiprocessor how many blocks of the kernel_batched_gemm_xdlops_bwd_weight instantiation below can be active per multiprocessor at BlockSize threads, passing dynamic_smem_size = 0. The result is kept in value_. */ struct MaximumActiveBlocksPerMultiprocessor + { + MaximumActiveBlocksPerMultiprocessor() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + AccDataType, + OutElementwiseOperation, + InElementwiseOperation, + element_wise::PassThrough, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true>, + BlockSize, + dynamic_smem_size)); + /* clamp so value_ is never below 1, even if the query reports 0 */ value_ = std::max(1, max_occupancy); + } + int value_; + }; + + /* NOTE(review): k_batch_ is no longer declared in this struct (see the removed member at the end of this patch); presumably it now comes from ArgumentSplitK - confirm in split_k_arg.hpp. */ struct Argument : public BaseArgument, public ArgumentSplitK { Argument( const InDataType* p_in_grid, @@ -592,9 +623,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle output_spatial_lengths_{}, 
conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + /* function-local static: constructed the first time this constructor runs and reused by later Argument constructions */ static MaximumActiveBlocksPerMultiprocessor max_occupancy; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -611,6 +643,39 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + /* Negative split_k: derive k_batch_ instead of taking it from the caller. Descriptors are built with k_batch = 1, the grid size computed from them is multiplied by Conv_G_, and that value together with max_occupancy.value_ is handed to get_best_occupancy_k_batch_value. */ if (split_k < 0) + { + constexpr int k_batch_initial = 1; + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_initial); + + const auto& ce_grid_desc_m_n = descs_initial[I2]; + const auto& block_2_ctile_map = + GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n, M01, N01, k_batch_initial); + + const auto grid_size = block_2_ctile_map.CalculateGridSize(ce_grid_desc_m_n) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size); + } + /* split_k >= 0 is stored unchanged */ else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -713,7 +778,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; };