diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index e416f58fa5..4be3949643 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -50,63 +50,110 @@ static_assert(ckb::HasABlockTransferInfo); static_assert(ckb::HasBBlockTransferInfo); static_assert(ckb::HasCBlockTransferInfo); -TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0) +struct TestCase { - static constexpr FwdConvAlgorithm algorithm{ - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}}, - .tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }}; - using Builder = ckb::ConvBuilder; - EXPECT_EQ( - Builder::Instance::TypeString(), - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, " - "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); - EXPECT_EQ(Builder::factory::TUNING.ak1, 16); - EXPECT_EQ(Builder::factory::TUNING.bk1, 16); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8); -} -TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0) -{ - static constexpr FwdConvAlgorithm algorithm{ - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }}; - using Builder = ckb::ConvBuilder; - EXPECT_EQ( - Builder::Instance::TypeString(), - "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " - "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); - EXPECT_EQ(Builder::factory::TUNING.ak1, 8); - EXPECT_EQ(Builder::factory::TUNING.bk1, 8); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64); - EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64); - EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1); - EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8); + std::string_view name; + FwdConvAlgorithm algorithm; + std::string_view expected_type; }; +constexpr std::array TEST_CASES = { + TestCase{ + .name = "ConvFwdXdlBf16CompInstances2x_0", + .algorithm = + {.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}}, + .tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, + .block_transfer{ + .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + }}, + .expected_type = + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, " + "2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", + }, + TestCase{ + .name = "GroupedConvFwdXdlBf16CompInstance0", + .algorithm = {.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, + .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, + .block_transfer{ + .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + }}, + .expected_type = + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, " + "4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", + }}; + +static constexpr int NUM_TEST_CASES = std::tuple_size_v; + +// Helper to generate testing::Types, TestIndex<1>, ..., TestIndex>. +template +struct TestingIndices +{ + template + struct TestIndex + { + static constexpr int index = INDEX; + }; + template + static auto GenerateTypes(std::integer_sequence) + { + return ::testing::Types...>{}; + } + // testing::Types sequence of TestIndex types. + using Types = decltype(GenerateTypes(std::make_integer_sequence{})); +}; + +// A typed test suite so we can instantiate all the kernel builders. +template +class ConvBuilderTest : public ::testing::Test +{ + protected: + static constexpr int N = T::index; + static constexpr const std::string_view& NAME = TEST_CASES[N].name; + static constexpr auto& ALGORITHM = TEST_CASES[N].algorithm; + static constexpr const std::string_view& EXPECTED_TYPE = TEST_CASES[N].expected_type; +}; + +struct TestNameGenerator +{ + template + static std::string GetName(int index) + { + return std::to_string(index) + "." + std::string(TEST_CASES[index].name); + } +}; + +TYPED_TEST_SUITE(ConvBuilderTest, TestingIndices::Types, TestNameGenerator); + +// General test case, instantiated for each test case. +TYPED_TEST(ConvBuilderTest, TestInstance) +{ + static constexpr const FwdConvAlgorithm& ALGORITHM = ConvBuilderTest::ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_EQ(Builder::Instance::TypeString(), ConvBuilderTest::EXPECTED_TYPE); + const auto& tp = ALGORITHM.tuning_params; + EXPECT_EQ(Builder::factory::TUNING.ak1, tp.ak1); + EXPECT_EQ(Builder::factory::TUNING.bk1, tp.bk1); + const auto& tcla = ALGORITHM.block_transfer.thread_cluster_lengths_a; + EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], tcla.k0); + EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], tcla.m); + EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], tcla.k1); + const auto& tclb = ALGORITHM.block_transfer.thread_cluster_lengths_b; + EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], tclb.k0); + EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], tclb.n); + EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], tclb.k1); + const auto& tclc = ALGORITHM.block_transfer.thread_cluster_lengths_c; + EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], tclc.m_block); + EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], tclc.m_wave_per_xdl); + EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], tclc.n_block); + EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], tclc.n_wave_per_xdl); +} + } // namespace