From a2e661cd5736f5ba6855e328bc06bb72d768f031 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Tue, 2 Sep 2025 17:22:29 +0000 Subject: [PATCH] Making algorithm a non-type parameter This simplifies the design by continuing to reduce the number of types and avoiding extra use of constexpr. --- .../ck_tile/builder/conv_algorithm.hpp | 4 +- .../include/ck_tile/builder/conv_builder.hpp | 23 +++++--- .../include/ck_tile/builder/conv_factory.hpp | 28 +++++----- .../builder/test/test_conv_builder.cpp | 55 ++++++------------- 4 files changed, 50 insertions(+), 60 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index dcad49dd75..6558350a1d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -37,7 +37,7 @@ static_assert(ThreadBlockInfo); // Concept to check if struct provides thread block info. template concept HasThreadBlockInfo = requires { - { T::THREAD_BLOCK } -> ThreadBlockInfo; + { T::thread_block } -> ThreadBlockInfo; }; // Concept for tuning parameters for a convolution problem. @@ -64,7 +64,7 @@ static_assert(ConvTuningInfo); // Concept to check if a struct provides convolution tuning info. template concept HasConvTuningInfo = requires { - { T::TUNING_PARAMS } -> ConvTuningInfo; + { T::tuning_params } -> ConvTuningInfo; }; // No requirements yet for a ConvAlogorithm concept. diff --git a/experimental/builder/include/ck_tile/builder/conv_builder.hpp b/experimental/builder/include/ck_tile/builder/conv_builder.hpp index 088ce34fba..7e7dfa9d29 100644 --- a/experimental/builder/include/ck_tile/builder/conv_builder.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_builder.hpp @@ -11,18 +11,25 @@ namespace ck_tile::builder { -template +/** + * @brief Top-level builder for creating convolution kernel instances. 
+ * + * This struct serves as the main entry point for generating a convolution kernel. + * It uses a factory pattern based on the provided signature, algorithm, and version + * to construct the appropriate kernel instance. + * + * @tparam TSignature The convolution signature, which describes the mathematical functionality of + * the algorithm (e.g., data types, layouts, direction). + * @tparam ALGORITHM The specific convolution algorithm to be used for the implementation. + * @tparam Version The version of the builder implementation. + */ +template requires SupportedVersion struct ConvBuilder { - // Input: Signature describes the mathematical funcationality of the algorithm. - using Signature = TSignature; - // Input: Algorithm describes the implementation of the algorithm. - using Algorithm = TAlgorithm; - // Input: Version of the builder, exposed for testing. + using Signature = TSignature; static constexpr auto kVersion = Version; - // Implmentation: The factory handles the builder logic. - using builder = GroupedConvForwardXldCShuffleFactoryV3; + using builder = GroupedConvForwardXldCShuffleFactoryV3; // Output: The kernel class. 
using Instance = builder::Instance; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 23db8ab68b..6ffc05ca99 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -73,12 +73,13 @@ struct ConvBlock MNK per_block; }; -template +template constexpr ConvBlock SetThreadBlockInfo() { - if constexpr(HasThreadBlockInfo) + using AlgorithmType = decltype(ALGORITHM); + if constexpr(HasThreadBlockInfo) { - constexpr auto& TB = Algo::THREAD_BLOCK; + constexpr auto& TB = ALGORITHM.thread_block; return ConvBlock{ .block_size = TB.block_size, .per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}}; @@ -101,19 +102,20 @@ struct ConvTuning int n_xdl_per_wave = 0; }; -template +template constexpr ConvTuning SetConvTuningInfo() { - if constexpr(HasConvTuningInfo) + using AlgorithmType = decltype(ALGORITHM); + if constexpr(HasConvTuningInfo) { - constexpr auto TI = Algo::TUNING_PARAMS; + constexpr auto& TP = ALGORITHM.tuning_params; return ConvTuning{ - .ak1 = TI.ak1, - .bk1 = TI.bk1, + .ak1 = TP.ak1, + .bk1 = TP.bk1, .m_per_xdl = 32, .n_per_dxl = 32, - .m_xdl_per_wave = TI.m_xdl_per_wave, - .n_xdl_per_wave = TI.n_xdl_per_wave, + .m_xdl_per_wave = TP.m_xdl_per_wave, + .n_xdl_per_wave = TP.n_xdl_per_wave, }; } // Default values. @@ -149,7 +151,7 @@ struct CBlockTransfer }; // Factory builds an instance of a grouped convolution kernel. 
-template +template requires SupportedVersion struct GroupedConvForwardXldCShuffleFactoryV3 { @@ -161,8 +163,8 @@ struct GroupedConvForwardXldCShuffleFactoryV3 .conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, }; - static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); - static constexpr ConvTuning TUNING = SetConvTuningInfo(); + static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); + static constexpr ConvTuning TUNING = SetConvTuningInfo(); static constexpr BlockTransfer A_BLOCK_TRANSFER{ .thread_cluster_lengths = {4, 64, 1}, .thread_cluster_order = {1, 0, 2}, diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index 57a86ac65d..391db57010 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -15,68 +15,49 @@ struct FwdConvSignature }; static_assert(ckb::ConvSignature); -struct FwdConvAlgorithm +struct DefaultFwdConvAlgorithm { }; -static_assert(ckb::ConvAlgorithm); +static_assert(ckb::ConvAlgorithm); static constexpr char API_VERSION[] = "0.1.0"; TEST(ConvBuilderTest, TestDefaultInstance) { - using Builder = ckb::ConvBuilder; + static constexpr DefaultFwdConvAlgorithm algorithm; + using Builder = ckb::ConvBuilder; EXPECT_EQ( Builder::Instance::TypeString(), "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); } -[[maybe_unused]] static constexpr ckb::ThreadBlock THREAD_BLOCK_256_256_256_32{ - .block_size = 256, - .sub_matrix = {.m = 256, .n = 256, .k = 32}, -}; - -struct ConvFwdXdlBf16CompInstances2xAlgorithm0 +struct FwdConvAlgorithm { - static constexpr ckb::ThreadBlock THREAD_BLOCK{ - .block_size = 256, - .sub_matrix = {.m = 256, .n = 128, .k = 64}, - }; - static constexpr ckb::ConvTuningParams 
TUNING_PARAMS{ - .ak1 = 16, - .bk1 = 16, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 2, - }; + ckb::ThreadBlock thread_block; + ckb::ConvTuningParams tuning_params; }; +static_assert(ckb::ConvAlgorithm); +static_assert(ckb::HasThreadBlockInfo); +static_assert(ckb::HasConvTuningInfo); TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0) { - using Builder = - ckb::ConvBuilder; + static constexpr FwdConvAlgorithm algorithm{ + .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}}, + .tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; + using Builder = ckb::ConvBuilder; EXPECT_EQ( Builder::Instance::TypeString(), "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, " "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); } - -struct ConvFwdXdlBf16CompAlgorithm0 -{ - static constexpr ckb::ThreadBlock THREAD_BLOCK{ - .block_size = 256, - .sub_matrix = {.m = 256, .n = 256, .k = 32}, - }; - static constexpr ckb::ConvTuningParams TUNING_PARAMS{ - .ak1 = 8, - .bk1 = 8, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4, - }; -}; - TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0) { - using Builder = ckb::ConvBuilder; + static constexpr FwdConvAlgorithm algorithm{ + .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, + .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + using Builder = ckb::ConvBuilder; EXPECT_EQ( Builder::Instance::TypeString(), "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "