mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Making alorithm a non-type parameter
This simplifies the design by continuing to reduce the number of types and avoidng extra use of constexpr.
This commit is contained in:
@@ -37,7 +37,7 @@ static_assert(ThreadBlockInfo<ThreadBlock>);
|
||||
// Concept to check if struct provides thread block info.
|
||||
template <typename T>
|
||||
concept HasThreadBlockInfo = requires {
|
||||
{ T::THREAD_BLOCK } -> ThreadBlockInfo;
|
||||
{ T::thread_block } -> ThreadBlockInfo;
|
||||
};
|
||||
|
||||
// Concept for tuning parameters for a convolution problem.
|
||||
@@ -64,7 +64,7 @@ static_assert(ConvTuningInfo<ConvTuningParams>);
|
||||
// Concept to check if a struct provides convolution tuning info.
|
||||
template <typename T>
|
||||
concept HasConvTuningInfo = requires {
|
||||
{ T::TUNING_PARAMS } -> ConvTuningInfo;
|
||||
{ T::tuning_params } -> ConvTuningInfo;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlogorithm concept.
|
||||
|
||||
@@ -11,18 +11,25 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
template <ConvSignature TSignature, ConvAlgorithm TAlgorithm, auto Version>
|
||||
/**
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam TSignature The convolution signature, which describes the mathematical functionality of
|
||||
* the algorithm (e.g., data types, layouts, direction).
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam Version The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignature TSignature, ConvAlgorithm auto ALGORITHM, auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
struct ConvBuilder
|
||||
{
|
||||
// Input: Signature describes the mathematical funcationality of the algorithm.
|
||||
using Signature = TSignature;
|
||||
// Input: Algorithm describes the implementation of the algorithm.
|
||||
using Algorithm = TAlgorithm;
|
||||
// Input: Version of the builder, exposed for testing.
|
||||
using Signature = TSignature;
|
||||
static constexpr auto kVersion = Version;
|
||||
// Implmentation: The factory handles the builder logic.
|
||||
using builder = GroupedConvForwardXldCShuffleFactoryV3<Signature, Algorithm, Version>;
|
||||
using builder = GroupedConvForwardXldCShuffleFactoryV3<Signature, ALGORITHM, Version>;
|
||||
// Output: The kernel class.
|
||||
using Instance = builder::Instance;
|
||||
};
|
||||
|
||||
@@ -73,12 +73,13 @@ struct ConvBlock
|
||||
MNK<int> per_block;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm Algo>
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
if constexpr(HasThreadBlockInfo<Algo>)
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasThreadBlockInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TB = Algo::THREAD_BLOCK;
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{
|
||||
.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}};
|
||||
@@ -101,19 +102,20 @@ struct ConvTuning
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm Algo>
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr ConvTuning SetConvTuningInfo()
|
||||
{
|
||||
if constexpr(HasConvTuningInfo<Algo>)
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasConvTuningInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto TI = Algo::TUNING_PARAMS;
|
||||
constexpr auto& TP = ALGORITHM.tuning_params;
|
||||
return ConvTuning{
|
||||
.ak1 = TI.ak1,
|
||||
.bk1 = TI.bk1,
|
||||
.ak1 = TP.ak1,
|
||||
.bk1 = TP.bk1,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_dxl = 32,
|
||||
.m_xdl_per_wave = TI.m_xdl_per_wave,
|
||||
.n_xdl_per_wave = TI.n_xdl_per_wave,
|
||||
.m_xdl_per_wave = TP.m_xdl_per_wave,
|
||||
.n_xdl_per_wave = TP.n_xdl_per_wave,
|
||||
};
|
||||
}
|
||||
// Default values.
|
||||
@@ -149,7 +151,7 @@ struct CBlockTransfer
|
||||
};
|
||||
|
||||
// Factory builds an instance of a grouped convolution kernel.
|
||||
template <ConvSignature Signature, ConvAlgorithm Algorithm, auto Version>
|
||||
template <ConvSignature Signature, ConvAlgorithm auto ALGORITHM, auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
{
|
||||
@@ -161,8 +163,8 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<Algorithm>();
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
|
||||
@@ -15,68 +15,49 @@ struct FwdConvSignature
|
||||
};
|
||||
static_assert(ckb::ConvSignature<FwdConvSignature>);
|
||||
|
||||
struct FwdConvAlgorithm
|
||||
struct DefaultFwdConvAlgorithm
|
||||
{
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
|
||||
static_assert(ckb::ConvAlgorithm<DefaultFwdConvAlgorithm>);
|
||||
|
||||
static constexpr char API_VERSION[] = "0.1.0";
|
||||
|
||||
TEST(ConvBuilderTest, TestDefaultInstance)
|
||||
{
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm, API_VERSION>;
|
||||
static constexpr DefaultFwdConvAlgorithm algorithm;
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
}
|
||||
|
||||
[[maybe_unused]] static constexpr ckb::ThreadBlock THREAD_BLOCK_256_256_256_32{
|
||||
.block_size = 256,
|
||||
.sub_matrix = {.m = 256, .n = 256, .k = 32},
|
||||
};
|
||||
|
||||
struct ConvFwdXdlBf16CompInstances2xAlgorithm0
|
||||
struct FwdConvAlgorithm
|
||||
{
|
||||
static constexpr ckb::ThreadBlock THREAD_BLOCK{
|
||||
.block_size = 256,
|
||||
.sub_matrix = {.m = 256, .n = 128, .k = 64},
|
||||
};
|
||||
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
|
||||
.ak1 = 16,
|
||||
.bk1 = 16,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 2,
|
||||
};
|
||||
ckb::ThreadBlock thread_block;
|
||||
ckb::ConvTuningParams tuning_params;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasThreadBlockInfo<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasConvTuningInfo<FwdConvAlgorithm>);
|
||||
|
||||
TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
|
||||
{
|
||||
using Builder =
|
||||
ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompInstances2xAlgorithm0, API_VERSION>;
|
||||
static constexpr FwdConvAlgorithm algorithm{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, "
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
}
|
||||
|
||||
struct ConvFwdXdlBf16CompAlgorithm0
|
||||
{
|
||||
static constexpr ckb::ThreadBlock THREAD_BLOCK{
|
||||
.block_size = 256,
|
||||
.sub_matrix = {.m = 256, .n = 256, .k = 32},
|
||||
};
|
||||
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4,
|
||||
};
|
||||
};
|
||||
|
||||
TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
|
||||
{
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompAlgorithm0, API_VERSION>;
|
||||
static constexpr FwdConvAlgorithm algorithm{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
|
||||
|
||||
Reference in New Issue
Block a user