[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)

[CK] Add fwd conv group merging to v3 conv instances
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Added conv group merging to the (universal) V3 fwd conv pipeline. The
new instance improves fwd conv performance when the number of
input/output channel per group is low.

On MI300 (`gfx942`) we get

| CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) |
|:-----|:------:|------:|
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1
| 3.86035 | 8.36796 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1
| 10.1867 | 13.4677 |
| grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1
| 11.7875 | 16.3657 |
This commit is contained in:
Ville Pietilä
2026-02-08 11:35:56 +00:00
committed by assistant-librarian[bot]
parent 4266f867d6
commit 57d26db844
19 changed files with 140 additions and 46 deletions

View File

@@ -35,7 +35,8 @@ TEST(FwdConvInstances,
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v2_intrawave);
.with_block_gemm(BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -31,7 +31,8 @@ TEST(FwdConvInstances,
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
.with_block_gemm(BlockGemmDesc_v1_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -69,7 +70,8 @@ TEST(FwdConvInstances,
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v5_intrawave);
.with_block_gemm(BlockGemmDesc_v5_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -32,7 +32,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
.with_transfer(cku::Transfer_4x64x1)
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

View File

@@ -31,7 +31,8 @@ TEST(FwdConvInstances,
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v3_intrawave);
.with_block_gemm(BlockGemmDesc_v3_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -32,7 +32,8 @@ TEST(FwdConvInstances,
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);
.with_block_gemm(BlockGemmDesc_v4_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -32,7 +32,8 @@ TEST(FwdConvInstances,
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
.with_block_gemm(BlockGemmDesc_v1_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -1188,7 +1188,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
false, // DirectLoad
1>; // NumGroupsToMerge
// Use ConvTraitsTmpl to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

View File

@@ -97,6 +97,7 @@ static constexpr int kCShuffleMXdlPerWavePerShuffle = 1;
static constexpr int kCShuffleNXdlPerWavePerShuffle = 1;
static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr bool kDirectLoad = false;
static constexpr int kNumGroupsToMerge = 1;
using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
@@ -176,7 +177,8 @@ using DeviceInstanceForTests_V3 =
BlkGemmPipelineVer,
ADataType,
BDataType,
defaults::kDirectLoad>;
defaults::kDirectLoad,
defaults::kNumGroupsToMerge>;
// Test case helper for specialization testing
template <ck::tensor_operation::device::ConvolutionForwardSpecialization Spec>

View File

@@ -102,7 +102,8 @@ TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3)
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
false, // DirectLoad
1>; // NumGroupsToMerge
using InstTraits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();