Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.

This commit is contained in:
Ville Pietilä
2025-07-11 07:21:53 +00:00
parent e19f337b9a
commit 5bd4a60d36

View File

@@ -19,6 +19,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -543,7 +545,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument
struct MaximumActiveBlocksPerMultiprocessor
{
MaximumActiveBlocksPerMultiprocessor()
{
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType,
BDataType,
AccDataType,
OutElementwiseOperation,
InElementwiseOperation,
element_wise::PassThrough,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
true>,
BlockSize,
dynamic_smem_size));
value_ = std::max(1, max_occupancy);
}
int value_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
const InDataType* p_in_grid,
@@ -592,9 +623,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static MaximumActiveBlocksPerMultiprocessor max_occupancy;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -611,6 +643,39 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if (split_k < 0)
{
constexpr int k_batch_initial = 1;
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_initial);
const auto& ce_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map =
GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n, M01, N01, k_batch_initial);
const auto grid_size = block_2_ctile_map.CalculateGridSize(ce_grid_desc_m_n) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size);
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -713,7 +778,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};