diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index d45186ece6..dcad49dd75 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -5,15 +5,15 @@ namespace ck_tile::builder { -// Convenience struct for a tuple of m, n, and k values. +// Convenience struct for a tuple of m, n, and k values. template -struct MNK { +struct MNK +{ T m{}; T n{}; T k{}; }; - // Concept for thread block info for a GEMM problem. template concept ThreadBlockInfo = requires(T t) { @@ -23,9 +23,9 @@ concept ThreadBlockInfo = requires(T t) { { t.sub_matrix.k } -> std::convertible_to; }; - // Describe a thread block for a GEMM. -struct ThreadBlock { +struct ThreadBlock +{ // Thread block size. int block_size; // Size of the submatrix problem in a thread block. @@ -40,6 +40,32 @@ concept HasThreadBlockInfo = requires { { T::THREAD_BLOCK } -> ThreadBlockInfo; }; +// Concept for tuning parameters for a convolution problem. +template +concept ConvTuningInfo = requires(T t) { + { t.ak1 } -> std::convertible_to; + { t.bk1 } -> std::convertible_to; + { t.m_xdl_per_wave } -> std::convertible_to; + { t.n_xdl_per_wave } -> std::convertible_to; +}; + +// Describe some convolution tuning parameters. +struct ConvTuningParams +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + int ak1 = 0; + int bk1 = 0; + int m_xdl_per_wave = 0; + int n_xdl_per_wave = 0; +}; + +static_assert(ConvTuningInfo); + +// Concept to check if a struct provides convolution tuning info. +template +concept HasConvTuningInfo = requires { + { T::TUNING_PARAMS } -> ConvTuningInfo; +}; // No requirements yet for a ConvAlogorithm concept. template diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 72011fce9f..23db8ab68b 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -94,13 +94,39 @@ constexpr ConvBlock SetThreadBlockInfo() struct ConvTuning { int ak1 = 0; - int ak2 = 0; + int bk1 = 0; int m_per_xdl = 0; int n_per_dxl = 0; int m_xdl_per_wave = 0; int n_xdl_per_wave = 0; }; +template +constexpr ConvTuning SetConvTuningInfo() +{ + if constexpr(HasConvTuningInfo) + { + constexpr auto TI = Algo::TUNING_PARAMS; + return ConvTuning{ + .ak1 = TI.ak1, + .bk1 = TI.bk1, + .m_per_xdl = 32, + .n_per_dxl = 32, + .m_xdl_per_wave = TI.m_xdl_per_wave, + .n_xdl_per_wave = TI.n_xdl_per_wave, + }; + } + // Default values. + return ConvTuning{ + .ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_dxl = 32, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4, + }; +} + // Block tranfser paramters for A or B tensor. struct BlockTransfer { @@ -135,15 +161,8 @@ struct GroupedConvForwardXldCShuffleFactoryV3 .conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, }; - static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); - static constexpr ConvTuning TUNING{ - .ak1 = 8, - .ak2 = 8, - .m_per_xdl = 32, - .n_per_dxl = 32, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4, - }; + static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); + static constexpr ConvTuning TUNING = SetConvTuningInfo(); static constexpr BlockTransfer A_BLOCK_TRANSFER{ .thread_cluster_lengths = {4, 64, 1}, .thread_cluster_order = {1, 0, 2}, @@ -194,7 +213,7 @@ struct GroupedConvForwardXldCShuffleFactoryV3 BLOCK.per_block.n, BLOCK.per_block.k, TUNING.ak1, - TUNING.ak2, + TUNING.bk1, TUNING.m_per_xdl, TUNING.n_per_dxl, TUNING.m_xdl_per_wave, diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index 0368e4bb79..57a86ac65d 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -17,26 +17,36 @@ static_assert(ckb::ConvSignature); struct FwdConvAlgorithm { - // TODO: Add algorithm info. }; static_assert(ckb::ConvAlgorithm); static constexpr char API_VERSION[] = "0.1.0"; -using FwdConvBuilder = ckb::ConvBuilder; TEST(ConvBuilderTest, TestDefaultInstance) { + using Builder = ckb::ConvBuilder; EXPECT_EQ( - FwdConvBuilder::Instance::TypeString(), + Builder::Instance::TypeString(), "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); } +[[maybe_unused]] static constexpr ckb::ThreadBlock THREAD_BLOCK_256_256_256_32{ + .block_size = 256, + .sub_matrix = {.m = 256, .n = 256, .k = 32}, +}; + struct ConvFwdXdlBf16CompInstances2xAlgorithm0 { static constexpr ckb::ThreadBlock THREAD_BLOCK{ .block_size = 256, - .sub_matrix = {.m = 256, .n = 256, .k = 32}, + .sub_matrix = {.m = 256, .n = 128, .k = 64}, + }; + static constexpr ckb::ConvTuningParams TUNING_PARAMS{ + .ak1 = 16, + .bk1 = 16, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 2, }; }; @@ -46,8 +56,31 @@ TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0) ckb::ConvBuilder; EXPECT_EQ( Builder::Instance::TypeString(), - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, " "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); } +struct ConvFwdXdlBf16CompAlgorithm0 +{ + static constexpr ckb::ThreadBlock THREAD_BLOCK{ + .block_size = 256, + .sub_matrix = {.m = 256, .n = 256, .k = 32}, + }; + static constexpr ckb::ConvTuningParams TUNING_PARAMS{ + .ak1 = 8, + .bk1 = 8, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4, + }; +}; + +TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0) +{ + using Builder = ckb::ConvBuilder; + EXPECT_EQ( + Builder::Instance::TypeString(), + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " + "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); +}; + } // namespace