mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Use a set_submatrix helper.
We want to simplify testing a large number of instances, so this commit adds a helper that makes the test and instantiation details clearer.
This commit is contained in:
@@ -19,9 +19,9 @@ struct MNK
|
||||
template <typename T>
|
||||
concept ThreadBlockInfo = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.m } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.n } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.k } -> std::convertible_to<int>;
|
||||
{ t.submatrix.m } -> std::convertible_to<int>;
|
||||
{ t.submatrix.n } -> std::convertible_to<int>;
|
||||
{ t.submatrix.k } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe a thread block for a GEMM.
|
||||
@@ -30,7 +30,7 @@ struct ThreadBlock
|
||||
// Thread block size.
|
||||
int block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<int> sub_matrix;
|
||||
MNK<int> submatrix;
|
||||
};
|
||||
static_assert(ThreadBlockInfo<ThreadBlock>);
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ constexpr ConvBlock SetThreadBlockInfo()
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{
|
||||
.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}};
|
||||
.per_block = {.m = TB.submatrix.m, .n = TB.submatrix.n, .k = TB.submatrix.k}};
|
||||
}
|
||||
// Default values if thread block info isn't specified.
|
||||
return ConvBlock{
|
||||
|
||||
@@ -44,23 +44,27 @@ struct TestCase
|
||||
std::string_view expected_type;
|
||||
};
|
||||
|
||||
// Helper function to set the sub_matrix size.
|
||||
constexpr ckb::ThreadBlock set_submatrix(int m, int n, int k)
|
||||
{
|
||||
return {.block_size = 256, .submatrix = {.m = m, .n = n, .k = k}};
|
||||
}
|
||||
|
||||
// Test cases to drive the typed test suite.
|
||||
constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
// double rate mfma instances on gfx950
|
||||
.name = "ConvFwdXdlBf16CompInstances2x_0",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
},
|
||||
{.thread_block = set_submatrix(256, 128, 64),
|
||||
.tuning_params = {.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, "
|
||||
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
|
||||
@@ -69,17 +73,15 @@ constexpr std::array TEST_CASES = {
|
||||
// Compute-friendly.
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance0",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
},
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
|
||||
@@ -87,17 +89,15 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance1",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
},
|
||||
{.thread_block = set_submatrix(128, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, "
|
||||
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
|
||||
@@ -105,17 +105,15 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance2",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
},
|
||||
{.thread_block = set_submatrix(128, 128, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, "
|
||||
"2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>",
|
||||
@@ -123,17 +121,15 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance3",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
|
||||
},
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V3},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>",
|
||||
@@ -141,17 +137,15 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance4",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V5,
|
||||
},
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V5},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
"4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>",
|
||||
@@ -160,14 +154,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance5",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(256, 128, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
|
||||
},
|
||||
.expected_type =
|
||||
@@ -178,14 +172,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance7",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(128, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
|
||||
},
|
||||
.expected_type =
|
||||
@@ -196,14 +190,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance8",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(128, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
|
||||
},
|
||||
.expected_type =
|
||||
@@ -214,14 +208,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 128, .n = 64, .k = 64}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(128, 64, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
|
||||
},
|
||||
.expected_type =
|
||||
@@ -232,14 +226,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(64, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
|
||||
},
|
||||
.expected_type =
|
||||
@@ -250,14 +244,14 @@ constexpr std::array TEST_CASES = {
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 64, .n = 64, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
},
|
||||
.thread_block = set_submatrix(64, 64, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V3,
|
||||
},
|
||||
.expected_type =
|
||||
|
||||
Reference in New Issue
Block a user