[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)

[CK] Add fwd conv group merging to v3 conv instances
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Added conv group merging to the (universal) V3 fwd conv pipeline. The
new instance improves fwd conv performance when the number of
input/output channels per group is low.

On MI300 (`gfx942`) we get:

| CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) |
|:-----|:------:|------:|
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 | 3.86035 | 8.36796 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 | 10.1867 | 13.4677 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 | 11.7875 | 16.3657 |
This commit is contained in:
Ville Pietilä
2026-02-08 11:35:56 +00:00
committed by assistant-librarian[bot]
parent 4266f867d6
commit 57d26db844
19 changed files with 140 additions and 46 deletions

View File

@@ -69,7 +69,7 @@ template <typename T>
concept FwdXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
// FWD WMMA algorithm concepts
template <typename T>

View File

@@ -161,7 +161,8 @@ struct ConvFwdXdlV3Factory
BLOCK_GEMM.pipeline_version,
typename Types::InComputeType,
typename Types::WeiComputeType,
IS_DIRECT_LOAD>;
IS_DIRECT_LOAD,
ALGORITHM.num_conv_groups_to_merge>;
};
} // namespace ck_tile::builder::factory

View File

@@ -71,7 +71,8 @@ template <index_t NDimSpatial,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType,
typename BComputeDataType,
bool DirectLoad>
bool DirectLoad,
index_t NumGroupsToMerge>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
} // namespace ck::tensor_operation::device
@@ -132,7 +133,8 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType_,
typename BComputeDataType_,
bool DirectLoad>
bool DirectLoad,
index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout_,
@@ -182,7 +184,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
BlkGemmPipelineVer,
AComputeDataType_,
BComputeDataType_,
DirectLoad>>
DirectLoad,
NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag;
@@ -270,6 +273,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
static constexpr bool kDirectLoad = DirectLoad;
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -351,6 +356,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad
oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge
oss << ">";
return oss.str();