mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4273 (commit 591f504)
[CK] Add fwd conv group merging to v3 conv instances MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Added conv group merging to the (universal) V3 fwd conv pipeline. The new instance improves fwd conv performance when the number of input/output channel per group is low. On MI300 (`gfx942`) we get | CK prof command | Baseline (TFLOPS) | V3 group merging (TFLOPS) | |:-----|:------:|------:| | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 | 3.86035 | 8.36796 | | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 | 10.1867 | 13.4677 | | grouped_conv_fwd 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 | 11.7875 | 16.3657 |
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4266f867d6
commit
57d26db844
@@ -69,7 +69,7 @@ template <typename T>
|
||||
concept FwdXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
// FWD WMMA algorithm concepts
|
||||
template <typename T>
|
||||
|
||||
@@ -161,7 +161,8 @@ struct ConvFwdXdlV3Factory
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
IS_DIRECT_LOAD>;
|
||||
IS_DIRECT_LOAD,
|
||||
ALGORITHM.num_conv_groups_to_merge>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
@@ -71,7 +71,8 @@ template <index_t NDimSpatial,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
bool DirectLoad>
|
||||
bool DirectLoad,
|
||||
index_t NumGroupsToMerge>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
@@ -132,7 +133,8 @@ template <ck::index_t NDimSpatial,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
bool DirectLoad>
|
||||
bool DirectLoad,
|
||||
index_t NumGroupsToMerge>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
@@ -182,7 +184,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
BlkGemmPipelineVer,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
DirectLoad>>
|
||||
DirectLoad,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag;
|
||||
@@ -270,6 +273,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
|
||||
static constexpr bool kDirectLoad = DirectLoad;
|
||||
|
||||
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
@@ -351,6 +356,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
|
||||
oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad
|
||||
oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
|
||||
@@ -35,7 +35,8 @@ TEST(FwdConvInstances,
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ TEST(FwdConvInstances,
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -69,7 +70,8 @@ TEST(FwdConvInstances,
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
|
||||
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
@@ -31,7 +31,8 @@ TEST(FwdConvInstances,
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ TEST(FwdConvInstances,
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ TEST(FwdConvInstances,
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -1188,7 +1188,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
false, // DirectLoad
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
// Use ConvTraitsTmpl to extract compile-time information
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
@@ -97,6 +97,7 @@ static constexpr int kCShuffleMXdlPerWavePerShuffle = 1;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = 1;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8;
|
||||
static constexpr bool kDirectLoad = false;
|
||||
static constexpr int kNumGroupsToMerge = 1;
|
||||
|
||||
using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>;
|
||||
using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
|
||||
@@ -176,7 +177,8 @@ using DeviceInstanceForTests_V3 =
|
||||
BlkGemmPipelineVer,
|
||||
ADataType,
|
||||
BDataType,
|
||||
defaults::kDirectLoad>;
|
||||
defaults::kDirectLoad,
|
||||
defaults::kNumGroupsToMerge>;
|
||||
|
||||
// Test case helper for specialization testing
|
||||
template <ck::tensor_operation::device::ConvolutionForwardSpecialization Spec>
|
||||
|
||||
@@ -102,7 +102,8 @@ TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3)
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
false, // DirectLoad
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
using InstTraits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
|
||||
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
|
||||
|
||||
@@ -566,7 +566,8 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
BlockGemm_>;
|
||||
BlockGemm_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
|
||||
@@ -164,6 +164,7 @@ struct DefaultAlgorithm
|
||||
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler =
|
||||
ckb::PipelineScheduler::INTRAWAVE};
|
||||
size_t num_conv_groups_to_merge = 1;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
|
||||
@@ -83,7 +83,8 @@ TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly)
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>;
|
||||
false, // DirectLoad
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
// Use InstanceTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
|
||||
@@ -225,7 +226,8 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
false, // DirectLoad
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
@@ -278,7 +280,8 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",false>"; // DirectLoad
|
||||
",false" // DirectLoad
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
@@ -77,7 +77,8 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
",v4" // BlkGemmPipelineVer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",false>"; // DirectLoad
|
||||
",false" // DirectLoad
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
// Test describe() through base class pointer for V3 variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3)
|
||||
|
||||
Reference in New Issue
Block a user