Use a set_submatrix helper.

We want to simplify the test of lots of instances, so adding a helper to make the test and instantiation details more clear.
This commit is contained in:
John Shumway
2025-09-06 16:46:57 +00:00
parent 5da397b9ec
commit 8b540c8df1
3 changed files with 113 additions and 119 deletions

View File

@@ -19,9 +19,9 @@ struct MNK
template <typename T>
concept ThreadBlockInfo = requires(T t) {
{ t.block_size } -> std::convertible_to<int>;
{ t.sub_matrix.m } -> std::convertible_to<int>;
{ t.sub_matrix.n } -> std::convertible_to<int>;
{ t.sub_matrix.k } -> std::convertible_to<int>;
{ t.submatrix.m } -> std::convertible_to<int>;
{ t.submatrix.n } -> std::convertible_to<int>;
{ t.submatrix.k } -> std::convertible_to<int>;
};
// Describe a thread block for a GEMM.
@@ -30,7 +30,7 @@ struct ThreadBlock
// Thread block size.
int block_size;
// Size of the submatrix problem in a thread block.
MNK<int> sub_matrix;
MNK<int> submatrix;
};
static_assert(ThreadBlockInfo<ThreadBlock>);

View File

@@ -82,7 +82,7 @@ constexpr ConvBlock SetThreadBlockInfo()
constexpr auto& TB = ALGORITHM.thread_block;
return ConvBlock{
.block_size = TB.block_size,
.per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}};
.per_block = {.m = TB.submatrix.m, .n = TB.submatrix.n, .k = TB.submatrix.k}};
}
// Default values if thread block info isn't specified.
return ConvBlock{

View File

@@ -44,23 +44,27 @@ struct TestCase
std::string_view expected_type;
};
// Helper function to set the sub_matrix size.
constexpr ckb::ThreadBlock set_submatrix(int m, int n, int k)
{
return {.block_size = 256, .submatrix = {.m = m, .n = n, .k = k}};
}
// Test cases to drive the typed test suite.
constexpr std::array TEST_CASES = {
TestCase{
// double rate mfma instances on gfx950
.name = "ConvFwdXdlBf16CompInstances2x_0",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
},
{.thread_block = set_submatrix(256, 128, 64),
.tuning_params = {.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, "
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
@@ -69,17 +73,15 @@ constexpr std::array TEST_CASES = {
// Compute-friendly.
.name = "GroupedConvFwdXdlBf16CompInstance0",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
},
{.thread_block = set_submatrix(256, 256, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
@@ -87,17 +89,15 @@ constexpr std::array TEST_CASES = {
TestCase{
.name = "GroupedConvFwdXdlBf16CompInstance1",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
},
{.thread_block = set_submatrix(128, 128, 64),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, "
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
@@ -105,17 +105,15 @@ constexpr std::array TEST_CASES = {
TestCase{
.name = "GroupedConvFwdXdlBf16CompInstance2",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
},
{.thread_block = set_submatrix(128, 128, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, "
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
@@ -123,17 +121,15 @@ constexpr std::array TEST_CASES = {
TestCase{
.name = "GroupedConvFwdXdlBf16CompInstance3",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
},
{.thread_block = set_submatrix(256, 256, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V3},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>",
@@ -141,17 +137,15 @@ constexpr std::array TEST_CASES = {
TestCase{
.name = "GroupedConvFwdXdlBf16CompInstance4",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.pipeline_version = ckb::BlockGemmPipelineVersion::V5,
},
{.thread_block = set_submatrix(256, 256, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V5},
.expected_type =
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>",
@@ -160,14 +154,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance5",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(256, 128, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
},
.expected_type =
@@ -178,14 +172,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance7",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(128, 256, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
},
.expected_type =
@@ -196,14 +190,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance8",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(128, 128, 64),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
},
.expected_type =
@@ -214,14 +208,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance9",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 64, .k = 64}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(128, 64, 64),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
},
.expected_type =
@@ -232,14 +226,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance9",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 128, .k = 64}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(64, 128, 64),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
},
.expected_type =
@@ -250,14 +244,14 @@ constexpr std::array TEST_CASES = {
.name = "GroupedConvFwdXdlBf16CompInstance9",
.algorithm =
{
.thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 64, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
},
.thread_block = set_submatrix(64, 64, 32),
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8}},
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
},
.expected_type =