diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index 7685348ef0..990bfb093d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -19,9 +19,9 @@ struct MNK template concept ThreadBlockInfo = requires(T t) { { t.block_size } -> std::convertible_to; - { t.sub_matrix.m } -> std::convertible_to; - { t.sub_matrix.n } -> std::convertible_to; - { t.sub_matrix.k } -> std::convertible_to; + { t.submatrix.m } -> std::convertible_to; + { t.submatrix.n } -> std::convertible_to; + { t.submatrix.k } -> std::convertible_to; }; // Describe a thread block for a GEMM. @@ -30,7 +30,7 @@ struct ThreadBlock // Thread block size. int block_size; // Size of the submatrix problem in a thread block. - MNK sub_matrix; + MNK submatrix; }; static_assert(ThreadBlockInfo); diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index c7da7fde2b..cb9ab0fb9e 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -82,7 +82,7 @@ constexpr ConvBlock SetThreadBlockInfo() constexpr auto& TB = ALGORITHM.thread_block; return ConvBlock{ .block_size = TB.block_size, - .per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}}; + .per_block = {.m = TB.submatrix.m, .n = TB.submatrix.n, .k = TB.submatrix.k}}; } // Default values if thread block info isn't specified. return ConvBlock{ diff --git a/experimental/builder/test/test_conv_instances.cpp b/experimental/builder/test/test_conv_instances.cpp index c340ec5aeb..5b3798a54b 100644 --- a/experimental/builder/test/test_conv_instances.cpp +++ b/experimental/builder/test/test_conv_instances.cpp @@ -44,23 +44,27 @@ struct TestCase std::string_view expected_type; }; +// Helper function to set the sub_matrix size. +constexpr ckb::ThreadBlock set_submatrix(int m, int n, int k) +{ + return {.block_size = 256, .submatrix = {.m = m, .n = n, .k = k}}; +} + // Test cases to drive the typed test suite. constexpr std::array TEST_CASES = { TestCase{ // double rate mfma instances on gfx950 .name = "ConvFwdXdlBf16CompInstances2x_0", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}}, - .tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V4, - }, + {.thread_block = set_submatrix(256, 128, 64), + .tuning_params = {.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V4}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, " "2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", @@ -69,17 +73,15 @@ constexpr std::array TEST_CASES = { // Compute-friendly. .name = "GroupedConvFwdXdlBf16CompInstance0", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V4, - }, + {.thread_block = set_submatrix(256, 256, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V4}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, " "4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", @@ -87,17 +89,15 @@ constexpr std::array TEST_CASES = { TestCase{ .name = "GroupedConvFwdXdlBf16CompInstance1", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V4, - }, + {.thread_block = set_submatrix(128, 128, 64), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V4}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, " "2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", @@ -105,17 +105,15 @@ constexpr std::array TEST_CASES = { TestCase{ .name = "GroupedConvFwdXdlBf16CompInstance2", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V4, - }, + {.thread_block = set_submatrix(128, 128, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V4}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, " "2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>", @@ -123,17 +121,15 @@ constexpr std::array TEST_CASES = { TestCase{ .name = "GroupedConvFwdXdlBf16CompInstance3", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V3, - }, + {.thread_block = set_submatrix(256, 256, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V3}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, " "4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>", @@ -141,17 +137,15 @@ constexpr std::array TEST_CASES = { TestCase{ .name = "GroupedConvFwdXdlBf16CompInstance4", .algorithm = - { - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, - .pipeline_version = ckb::BlockGemmPipelineVersion::V5, - }, + {.thread_block = set_submatrix(256, 256, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, + .pipeline_version = ckb::BlockGemmPipelineVersion::V5}, .expected_type = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, " "4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>", @@ -160,14 +154,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance5", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(256, 128, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V1, }, .expected_type = @@ -178,14 +172,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance7", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 256, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(128, 256, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V1, }, .expected_type = @@ -196,14 +190,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance8", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(128, 128, 64), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V1, }, .expected_type = @@ -214,14 +208,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance9", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 64, .k = 64}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(128, 64, 64), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V3, }, .expected_type = @@ -232,14 +226,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance9", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 128, .k = 64}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(64, 128, 64), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V3, }, .expected_type = @@ -250,14 +244,14 @@ constexpr std::array TEST_CASES = { .name = "GroupedConvFwdXdlBf16CompInstance9", .algorithm = { - .thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 64, .k = 32}}, - .tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, - .block_transfer{ - .thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, - .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, - .thread_cluster_lengths_c = - {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - }, + .thread_block = set_submatrix(64, 64, 32), + .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4}, + .block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1}, + .thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1}, + .thread_cluster_lengths_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}}, .pipeline_version = ckb::BlockGemmPipelineVersion::V3, }, .expected_type =