Making algorithm a non-type parameter

This simplifies the design by continuing to reduce the number of types and avoiding extra use of constexpr.
This commit is contained in:
John Shumway
2025-09-02 17:22:29 +00:00
parent a79616f323
commit a2e661cd57
4 changed files with 50 additions and 60 deletions

View File

@@ -37,7 +37,7 @@ static_assert(ThreadBlockInfo<ThreadBlock>);
// Concept to check if struct provides thread block info.
template <typename T>
concept HasThreadBlockInfo = requires {
{ T::THREAD_BLOCK } -> ThreadBlockInfo;
{ T::thread_block } -> ThreadBlockInfo;
};
// Concept for tuning parameters for a convolution problem.
@@ -64,7 +64,7 @@ static_assert(ConvTuningInfo<ConvTuningParams>);
// Concept to check if a struct provides convolution tuning info.
template <typename T>
concept HasConvTuningInfo = requires {
{ T::TUNING_PARAMS } -> ConvTuningInfo;
{ T::tuning_params } -> ConvTuningInfo;
};
// No requirements yet for a ConvAlgorithm concept.

View File

@@ -11,18 +11,25 @@
namespace ck_tile::builder {
template <ConvSignature TSignature, ConvAlgorithm TAlgorithm, auto Version>
/**
* @brief Top-level builder for creating convolution kernel instances.
*
* This struct serves as the main entry point for generating a convolution kernel.
* It uses a factory pattern based on the provided signature, algorithm, and version
* to construct the appropriate kernel instance.
*
* @tparam TSignature The convolution signature, which describes the mathematical functionality of
* the algorithm (e.g., data types, layouts, direction).
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
* @tparam Version The version of the builder implementation.
*/
template <ConvSignature TSignature, ConvAlgorithm auto ALGORITHM, auto Version>
requires SupportedVersion<Version>
struct ConvBuilder
{
// Input: Signature describes the mathematical functionality of the algorithm.
using Signature = TSignature;
// Input: Algorithm describes the implementation of the algorithm.
using Algorithm = TAlgorithm;
// Input: Version of the builder, exposed for testing.
using Signature = TSignature;
static constexpr auto kVersion = Version;
// Implementation: The factory handles the builder logic.
using builder = GroupedConvForwardXldCShuffleFactoryV3<Signature, Algorithm, Version>;
using builder = GroupedConvForwardXldCShuffleFactoryV3<Signature, ALGORITHM, Version>;
// Output: The kernel class.
using Instance = builder::Instance;
};

View File

@@ -73,12 +73,13 @@ struct ConvBlock
MNK<int> per_block;
};
template <ConvAlgorithm Algo>
template <ConvAlgorithm auto ALGORITHM>
constexpr ConvBlock SetThreadBlockInfo()
{
if constexpr(HasThreadBlockInfo<Algo>)
using AlgorithmType = decltype(ALGORITHM);
if constexpr(HasThreadBlockInfo<AlgorithmType>)
{
constexpr auto& TB = Algo::THREAD_BLOCK;
constexpr auto& TB = ALGORITHM.thread_block;
return ConvBlock{
.block_size = TB.block_size,
.per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}};
@@ -101,19 +102,20 @@ struct ConvTuning
int n_xdl_per_wave = 0;
};
template <ConvAlgorithm Algo>
template <ConvAlgorithm auto ALGORITHM>
constexpr ConvTuning SetConvTuningInfo()
{
if constexpr(HasConvTuningInfo<Algo>)
using AlgorithmType = decltype(ALGORITHM);
if constexpr(HasConvTuningInfo<AlgorithmType>)
{
constexpr auto TI = Algo::TUNING_PARAMS;
constexpr auto& TP = ALGORITHM.tuning_params;
return ConvTuning{
.ak1 = TI.ak1,
.bk1 = TI.bk1,
.ak1 = TP.ak1,
.bk1 = TP.bk1,
.m_per_xdl = 32,
.n_per_dxl = 32,
.m_xdl_per_wave = TI.m_xdl_per_wave,
.n_xdl_per_wave = TI.n_xdl_per_wave,
.m_xdl_per_wave = TP.m_xdl_per_wave,
.n_xdl_per_wave = TP.n_xdl_per_wave,
};
}
// Default values.
@@ -149,7 +151,7 @@ struct CBlockTransfer
};
// Factory builds an instance of a grouped convolution kernel.
template <ConvSignature Signature, ConvAlgorithm Algorithm, auto Version>
template <ConvSignature Signature, ConvAlgorithm auto ALGORITHM, auto Version>
requires SupportedVersion<Version>
struct GroupedConvForwardXldCShuffleFactoryV3
{
@@ -161,8 +163,8 @@ struct GroupedConvForwardXldCShuffleFactoryV3
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<Algorithm>();
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER{
.thread_cluster_lengths = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},

View File

@@ -15,68 +15,49 @@ struct FwdConvSignature
};
static_assert(ckb::ConvSignature<FwdConvSignature>);
struct FwdConvAlgorithm
struct DefaultFwdConvAlgorithm
{
};
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
static_assert(ckb::ConvAlgorithm<DefaultFwdConvAlgorithm>);
static constexpr char API_VERSION[] = "0.1.0";
TEST(ConvBuilderTest, TestDefaultInstance)
{
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm, API_VERSION>;
static constexpr DefaultFwdConvAlgorithm algorithm;
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
}
[[maybe_unused]] static constexpr ckb::ThreadBlock THREAD_BLOCK_256_256_256_32{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
};
struct ConvFwdXdlBf16CompInstances2xAlgorithm0
struct FwdConvAlgorithm
{
static constexpr ckb::ThreadBlock THREAD_BLOCK{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 128, .k = 64},
};
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
.ak1 = 16,
.bk1 = 16,
.m_xdl_per_wave = 2,
.n_xdl_per_wave = 2,
};
ckb::ThreadBlock thread_block;
ckb::ConvTuningParams tuning_params;
};
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
static_assert(ckb::HasThreadBlockInfo<FwdConvAlgorithm>);
static_assert(ckb::HasConvTuningInfo<FwdConvAlgorithm>);
TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
{
using Builder =
ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompInstances2xAlgorithm0, API_VERSION>;
static constexpr FwdConvAlgorithm algorithm{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
}
struct ConvFwdXdlBf16CompAlgorithm0
{
static constexpr ckb::ThreadBlock THREAD_BLOCK{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
};
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
.ak1 = 8,
.bk1 = 8,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4,
};
};
TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
{
using Builder = ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompAlgorithm0, API_VERSION>;
static constexpr FwdConvAlgorithm algorithm{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "