Add support for occupancy-based splitk

This commit is contained in:
Enrico Degregori
2025-08-14 09:19:58 +00:00
parent 45b3d26e3c
commit b56e9f6bc4
3 changed files with 139 additions and 13 deletions

View File

@@ -22,6 +22,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -440,7 +442,42 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
constexpr int dynamic_smem_size = 0;
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
int max_occupancy = 0;
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// TODO: implement
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -483,9 +520,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -507,6 +545,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
Conv_G_ / NumGroupsToMerge;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -682,7 +749,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
};
// Invoker

View File

@@ -22,6 +22,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -401,7 +403,41 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
constexpr int dynamic_smem_size = 0;
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
int max_occupancy = 0;
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// TODO: implement
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -444,9 +480,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -463,6 +500,35 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
@@ -633,7 +699,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};