mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)
[CK] Add fwd conv group merging to v3 conv instances MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Added conv group merging to the (universal) V3 fwd conv pipeline. The new instance improves fwd conv performance when the number of input/output channel per group is low. On MI300 (`gfx942`) we get | CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) | |:-----|:------:|------:| | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 | 3.86035 | 8.36796 | | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 | 10.1867 | 13.4677 | | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 | 11.7875 | 16.3657 |
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4266f867d6
commit
57d26db844
@@ -382,7 +382,8 @@ template <index_t NDimSpatial,
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -418,6 +419,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
@@ -447,7 +450,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
EDataType>;
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
|
||||
using ComputePtrOffset = ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>;
|
||||
|
||||
@@ -784,8 +788,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch/N Stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
@@ -796,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
|
||||
@@ -816,7 +822,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
|
||||
});
|
||||
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ =
|
||||
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
@@ -999,7 +1006,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
|
||||
|
||||
gdy = arg.num_group_;
|
||||
gdy = arg.num_group_ / NumGroupsToMerge;
|
||||
gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
@@ -1499,6 +1506,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
if(G % NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported! G % NumGroupsToMerge != 0: G=" << G
|
||||
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_device_name() == "gfx908")
|
||||
{
|
||||
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
|
||||
@@ -1595,6 +1615,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
|
||||
{
|
||||
if(C != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of A
|
||||
// FIXME: layout
|
||||
@@ -2106,6 +2142,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
if constexpr (NumGroupsToMerge > 1) {
|
||||
str << "_MergedGroups";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
@@ -2125,8 +2164,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer];
|
||||
if constexpr (NumGroupsToMerge > 1) {
|
||||
str << ", " << NumGroupsToMerge;
|
||||
}
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
Reference in New Issue
Block a user