mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
This commit is contained in:
@@ -22,6 +22,8 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -504,7 +506,57 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct MaximumActiveBlocksPerMultiprocessor
|
||||
{
|
||||
MaximumActiveBlocksPerMultiprocessor()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
int max_occupancy = 0;
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
NumGroupsToMerge,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
NumGroupsToMerge,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
value_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int value_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -547,9 +599,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static MaximumActiveBlocksPerMultiprocessor max_occupancy;
|
||||
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
@@ -576,6 +629,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
if (split_k < 0)
|
||||
{
|
||||
constexpr int k_batch_initial = 1;
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_initial);
|
||||
|
||||
const auto& a_grid_desc_kbatch_k0_m_k1 = descs_initial[I0];
|
||||
const auto& b_grid_desc_kbatch_k0_n_k1 = descs_initial[I1];
|
||||
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
|
||||
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
|
||||
|
||||
const auto grid_size_mn = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN);
|
||||
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size_mn, Conv_G_);
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline
|
||||
ck::index_t gemmK;
|
||||
std::tie(std::ignore, std::ignore, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: "
|
||||
<< k_batch_max << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: "
|
||||
<< k_batch_ << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -751,7 +854,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user