[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)

[CK] Add fwd conv group merging to v3 conv instances
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Added conv group merging to the (universal) V3 fwd conv pipeline. The
new instance improves fwd conv performance when the number of
input/output channel per group is low.

On MI300 (`gfx942`) we get

| CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) |
|:-----|:------:|------:|
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1
| 3.86035 | 8.36796 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1
| 10.1867 | 13.4677 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1
| 11.7875 | 16.3657 |
This commit is contained in:
Ville Pietilä
2026-02-08 11:35:56 +00:00
committed by assistant-librarian[bot]
parent 4266f867d6
commit 57d26db844
19 changed files with 140 additions and 46 deletions

View File

@@ -382,7 +382,8 @@ template <index_t NDimSpatial,
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
bool DirectLoad = false>
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
@@ -418,6 +419,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
Wave32MaxMNPerXDL,
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiD = DsDataType::Size() > 0;
@@ -447,7 +450,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
EDataType,
NumGroupsToMerge>;
using ComputePtrOffset = ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>;
@@ -784,8 +788,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op_{cde_element_op}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
// p_as and p_bs are pointers
@@ -796,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
@@ -816,7 +822,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideE_ =
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
@@ -999,7 +1006,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
gdy = arg.num_group_;
gdy = arg.num_group_ / NumGroupsToMerge;
gdz = num_workgroups_per_Conv_N;
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
@@ -1499,6 +1506,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
if constexpr(NumGroupsToMerge > 1)
{
if(G % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! G % NumGroupsToMerge != 0: G=" << G
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
}
return false;
}
}
if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
@@ -1595,6 +1615,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
}
// check vector access of A
// FIXME: layout
@@ -2106,6 +2142,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
if constexpr (NumGroupsToMerge > 1) {
str << "_MergedGroups";
}
str << "<"
<< BlockSize << ", "
@@ -2125,8 +2164,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< ">";
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer];
if constexpr (NumGroupsToMerge > 1) {
str << ", " << NumGroupsToMerge;
}
str << ">";
// clang-format on
return str.str();