mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)
[CK] Add fwd conv group merging to v3 conv instances

## Proposed changes

Added conv group merging to the (universal) V3 fwd conv pipeline. The new instance improves fwd conv performance when the number of input/output channels per group is low.

On MI300 (`gfx942`) we get:

| CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) |
|:-----|:------:|------:|
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 | 3.86035 | 8.36796 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 | 10.1867 | 13.4677 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 | 11.7875 | 16.3657 |
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4266f867d6
commit
57d26db844
@@ -382,7 +382,8 @@ template <index_t NDimSpatial,
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -418,6 +419,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
Wave32MaxMNPerXDL,
|
||||
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
|
||||
|
||||
static_assert(NumGroupsToMerge >= 1);
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
static constexpr bool isMultiD = DsDataType::Size() > 0;
|
||||
@@ -447,7 +450,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
EDataType>;
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
|
||||
using ComputePtrOffset = ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>;
|
||||
|
||||
@@ -784,8 +788,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch/N Stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
@@ -796,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
|
||||
@@ -816,7 +822,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
|
||||
});
|
||||
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ =
|
||||
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
@@ -999,7 +1006,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
|
||||
|
||||
gdy = arg.num_group_;
|
||||
gdy = arg.num_group_ / NumGroupsToMerge;
|
||||
gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
@@ -1499,6 +1506,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
if(G % NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported! G % NumGroupsToMerge != 0: G=" << G
|
||||
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(get_device_name() == "gfx908")
|
||||
{
|
||||
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
|
||||
@@ -1595,6 +1615,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
|
||||
{
|
||||
if(C != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of A
|
||||
// FIXME: layout
|
||||
@@ -2106,6 +2142,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
if constexpr (NumGroupsToMerge > 1) {
|
||||
str << "_MergedGroups";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
@@ -2125,8 +2164,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer];
|
||||
if constexpr (NumGroupsToMerge > 1) {
|
||||
str << ", " << NumGroupsToMerge;
|
||||
}
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -1380,11 +1380,11 @@ struct TransformConvFwdToGemm
|
||||
else
|
||||
{
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(NumGroupsToMerge, K_, ZYX_ * C_),
|
||||
make_tuple(GStrideTensorB_, KStrideTensorB_, CStrideTensorB_));
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_pass_through_transform(ZYX_ * C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
@@ -1550,20 +1550,20 @@ struct TransformConvFwdToGemm
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, 1, K_),
|
||||
make_tuple(NStrideTensorC_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only matrices from diagonal. X_or returns 0 for the same
|
||||
@@ -1577,13 +1577,13 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user