mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Automatic deduction of split-K value for grouped convolution (#2491)
* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Use simple best occupancy model to calculate the split-K. * Handle split-K autodeduction in explicit gemm conv. * Add unit tests for split-K autodeduction. * Remove oversubscription. * Small fixes. * Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle. * Run clang formatting. * Fix error handling in the conv profiler. * Add missing documentation for the autodeducted split-K values. * Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver. * Fix clang formatting and split-K profiler documentation. * Rename max_occupancy value variable. * Calculate grid size for split-K autodeduction directly from input array shapes and template params. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
}
|
||||
};
|
||||
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
int max_occupancy = 0;
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
|
||||
GridwiseGemm,
|
||||
Argument,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
|
||||
GridwiseGemm,
|
||||
Argument,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
@@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static ck::index_t GetMaxOccupancy()
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
return active_workgroups_per_cu.max_occupancy_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -142,6 +144,20 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_ = split_k;
|
||||
}
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
|
||||
@@ -176,7 +192,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
split_k_};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -199,7 +215,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
split_k_};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,6 +252,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
bool is_filter_data_packed;
|
||||
CElementwiseGridDesc elementwise_desc_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
ck::index_t split_k_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -542,7 +544,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
element_wise::PassThrough,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(
|
||||
const InDataType* p_in_grid,
|
||||
@@ -591,9 +622,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
@@ -610,6 +642,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
std::tie(gemmM, gemmN, std::ignore) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -712,7 +760,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -504,7 +506,55 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
int max_occupancy = 0;
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
NumGroupsToMerge,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
NumGroupsToMerge,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -547,9 +597,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
@@ -576,6 +627,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
std::tie(gemmM, gemmN, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
|
||||
Conv_G_ / NumGroupsToMerge;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
<< std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -751,7 +831,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
@@ -419,7 +421,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
false>, // Both true/false give the same occupancy.
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -463,9 +494,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
@@ -491,6 +523,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
std::tie(gemmM, gemmN, std::ignore) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -656,7 +705,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
@@ -381,7 +383,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
int max_occupancy = 0;
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -424,9 +472,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
@@ -443,6 +492,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
std::tie(gemmM, gemmN, gemmK) =
|
||||
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
|
||||
|
||||
const auto grid_size =
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Ensure that k_batch_ does not exceed the maximum value
|
||||
// for the GEMM pipeline.
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
<< std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -513,7 +591,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
|
||||
17
include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp
Normal file
17
include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct ArgumentSplitK
|
||||
{
|
||||
index_t k_batch_{1};
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct DeviceProperties
|
||||
{
|
||||
DeviceProperties()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
|
||||
num_cu_ = dev_prop.multiProcessorCount;
|
||||
};
|
||||
int num_cu_;
|
||||
};
|
||||
|
||||
inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
|
||||
{
|
||||
static DeviceProperties device_properties;
|
||||
const int max_capacity = max_occupancy * device_properties.num_cu_;
|
||||
|
||||
ck::index_t k_batch = 1;
|
||||
const auto optimal_split =
|
||||
static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
|
||||
if(optimal_split > 1)
|
||||
{
|
||||
k_batch = optimal_split;
|
||||
}
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
|
||||
<< max_occupancy << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
|
||||
}
|
||||
return k_batch;
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
inline auto
|
||||
get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
// The input array has elements in the order: G, N, K, Do, Ho, Wo
|
||||
// GemmK = N * Do * Ho * Wo for the BWD weight pass.
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
|
||||
|
||||
// The GEMM M dimension is the number of output channels.
|
||||
const auto gemmM = e_g_k_c_xs_lengths[I1];
|
||||
|
||||
// The output array has elements in the order: G, K, C, X, Y, Z
|
||||
// GemmN = C * X * Y * Z for the BWD weight pass.
|
||||
const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(e_g_k_c_xs_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
|
||||
return std::make_tuple(gemmM, gemmN, gemmK);
|
||||
}
|
||||
|
||||
template <ck::index_t MPerBlock, ck::index_t NPerBlock>
|
||||
inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock);
|
||||
return M0 * N0;
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -148,7 +148,7 @@
|
||||
# <dilations>, (ie Dy, Dx for 2D)
|
||||
# <left padding>, (ie LeftPy, LeftPx for 2D)
|
||||
# <right padding>, (ie RightPy, RightPx for 2D)
|
||||
# SplitK
|
||||
# SplitK (-1 for internally computed split-K value, positive value to set k batches explicitly, or 'all' to test all internal split-K values)
|
||||
|
||||
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
|
||||
./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
|
||||
@@ -40,7 +41,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k)
|
||||
const std::string& split_k)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -138,10 +139,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
ck::index_t best_split_k = 1;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
std::string best_split_k("1");
|
||||
|
||||
// profile device Conv instances
|
||||
bool all_pass = true;
|
||||
@@ -170,11 +171,20 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
std::vector<ck::index_t> split_k_list = {/*auto deduce value*/ -1, 1, 2, 4, 8, 16, 32, 64, 128};
|
||||
|
||||
if(split_k > 0)
|
||||
if(split_k != "all")
|
||||
{
|
||||
split_k_list = {split_k};
|
||||
try
|
||||
{
|
||||
ck::index_t split_k_value = std::stoi(split_k);
|
||||
split_k_list = {split_k_value};
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << e.what() << '\n';
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
@@ -200,6 +210,16 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
out_element_op,
|
||||
split_k_list[split_k_id]);
|
||||
|
||||
auto split_k_value = split_k_list[split_k_id];
|
||||
auto split_k_param_str = std::to_string(split_k_value);
|
||||
auto* split_k_arg =
|
||||
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
|
||||
if(split_k_arg && split_k_value < 0)
|
||||
{
|
||||
split_k_value = split_k_arg->k_batch_;
|
||||
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
|
||||
}
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
@@ -222,7 +242,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
|
||||
<< split_k_list[split_k_id] << std::endl;
|
||||
<< split_k_param_str << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
@@ -230,7 +250,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_split_k = split_k_list[split_k_id];
|
||||
best_split_k = split_k_param_str;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
@@ -244,7 +264,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_list[split_k_id];
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
|
||||
@@ -56,7 +56,9 @@ static void print_helper_msg()
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg()
|
||||
<< " SplitK (-1 for internally computed split-K value, positive value to set k "
|
||||
"batches explicitly, or 'all' to test all internal split-K values)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -88,7 +90,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
|
||||
|
||||
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
|
||||
const auto& split_k = std::string(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
using NDimSpatial = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
std::vector<ck::index_t> split_ks{-1, 1, 2};
|
||||
|
||||
bool skip_case(const ck::index_t split_k)
|
||||
{
|
||||
@@ -108,7 +108,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
std::to_string(split_k));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
ck::index_t split_k{2};
|
||||
std::vector<ck::index_t> split_ks{-1, 2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
@@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
|
||||
auto conv = GroupedConvBwdWeightDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
split_k);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
bool is_supported = true;
|
||||
|
||||
for(const auto split_k : split_ks)
|
||||
{
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
split_k);
|
||||
is_supported &= conv.IsSupportedArgument(argument);
|
||||
}
|
||||
return is_supported;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
ck::index_t split_k{2};
|
||||
std::vector<ck::index_t> split_ks{-1, 2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
@@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
|
||||
auto conv = GroupedConvBwdWeightDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
split_k);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
bool is_supported = true;
|
||||
|
||||
for(const auto split_k : split_ks)
|
||||
{
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
split_k);
|
||||
is_supported &= conv.IsSupportedArgument(argument);
|
||||
}
|
||||
return is_supported;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user