mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Use set_thread_cluster_dims helper.
This simplifies our instantiation test.
This commit is contained in:
@@ -122,19 +122,19 @@ static_assert(BlockCTransferLengths<BlockCTransferLengthsInfo>);
|
||||
// Concept to check if a struct provides A block transfer info.
|
||||
template <typename T>
|
||||
concept HasABlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_a } -> BlockATransferLengths;
|
||||
{ T::block_transfer.thread_cluster_dims_a } -> BlockATransferLengths;
|
||||
};
|
||||
|
||||
// Concept to check if a struct provides B block transfer info.
|
||||
template <typename T>
|
||||
concept HasBBlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_b } -> BlockBTransferLengths;
|
||||
{ T::block_transfer.thread_cluster_dims_b } -> BlockBTransferLengths;
|
||||
};
|
||||
|
||||
// Concept to check if a struct provides C block transfer info.
|
||||
template <typename T>
|
||||
concept HasCBlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_c } -> BlockCTransferLengths;
|
||||
{ T::block_transfer.thread_cluster_dims_c } -> BlockCTransferLengths;
|
||||
};
|
||||
|
||||
enum class BlockGemmPipelineVersion
|
||||
|
||||
@@ -132,29 +132,29 @@ constexpr ConvTuning SetConvTuningInfo()
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<int, 3> thread_cluster_lengths = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<int, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<int, 3> src_access_order = {0, 0, 0};
|
||||
int src_vector_dim = 0;
|
||||
int src_scaler_per_vector = 0;
|
||||
int dest_scaler_per_vector_k1 = 0;
|
||||
int add_extra = 0;
|
||||
ck::Array<int, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<int, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<int, 3> src_access_order = {0, 0, 0};
|
||||
int src_vector_dim = 0;
|
||||
int src_scaler_per_vector = 0;
|
||||
int dest_scaler_per_vector_k1 = 0;
|
||||
int add_extra = 0;
|
||||
};
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
int m_xdl_per_wave_per_shuffle = 0;
|
||||
int n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<int, 4> thread_cluster_lengths = {0, 0, 0, 0};
|
||||
int scaler_per_vector = 8;
|
||||
int m_xdl_per_wave_per_shuffle = 0;
|
||||
int n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<int, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
int scaler_per_vector = 8;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr BlockTransfer SetABlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_dims = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
@@ -165,8 +165,8 @@ constexpr BlockTransfer SetABlockTransfer()
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasABlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_a;
|
||||
block_transfer.thread_cluster_lengths = {TCL.k0, TCL.m, TCL.k1};
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_a;
|
||||
block_transfer.thread_cluster_dims = {TCL.k0, TCL.m, TCL.k1};
|
||||
}
|
||||
// Default.
|
||||
return block_transfer;
|
||||
@@ -176,7 +176,7 @@ template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr BlockTransfer SetBBlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_dims = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
@@ -187,8 +187,8 @@ constexpr BlockTransfer SetBBlockTransfer()
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasBBlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_b;
|
||||
block_transfer.thread_cluster_lengths = {TCL.k0, TCL.n, TCL.k1};
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_b;
|
||||
block_transfer.thread_cluster_dims = {TCL.k0, TCL.n, TCL.k1};
|
||||
}
|
||||
// Default.
|
||||
return block_transfer;
|
||||
@@ -200,14 +200,14 @@ constexpr CBlockTransfer SetCBlockTransfer()
|
||||
CBlockTransfer block_transfer{
|
||||
.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.thread_cluster_lengths = {1, 32, 1, 8},
|
||||
.thread_cluster_dims = {1, 32, 1, 8},
|
||||
.scaler_per_vector = 8,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasCBlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_c;
|
||||
block_transfer.thread_cluster_lengths = {
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
block_transfer.thread_cluster_dims = {
|
||||
TCL.m_block,
|
||||
TCL.m_wave_per_xdl,
|
||||
TCL.n_block,
|
||||
@@ -284,14 +284,14 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
TUNING.n_per_dxl,
|
||||
TUNING.m_xdl_per_wave,
|
||||
TUNING.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
A_BLOCK_TRANSFER.add_extra,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
@@ -300,7 +300,7 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
B_BLOCK_TRANSFER.add_extra,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scaler_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
namespace {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
using P = ckb::BlockGemmPipelineVersion;
|
||||
using P = ckb::BlockGemmPipelineVersion;
|
||||
struct FwdConvSignature
|
||||
{
|
||||
static constexpr int spatial_dim = 2;
|
||||
@@ -24,9 +24,9 @@ struct FwdConvAlgorithm
|
||||
ckb::ConvTuningParams tuning_params;
|
||||
struct BlockTransfer
|
||||
{
|
||||
ckb::BlockATransferLengthsInfo thread_cluster_lengths_a;
|
||||
ckb::BlockBTransferLengthsInfo thread_cluster_lengths_b;
|
||||
ckb::BlockCTransferLengthsInfo thread_cluster_lengths_c;
|
||||
ckb::BlockATransferLengthsInfo thread_cluster_dims_a;
|
||||
ckb::BlockBTransferLengthsInfo thread_cluster_dims_b;
|
||||
ckb::BlockCTransferLengthsInfo thread_cluster_dims_c;
|
||||
} block_transfer;
|
||||
ckb::BlockGemmPipelineVersion pipeline_version;
|
||||
};
|
||||
@@ -51,6 +51,15 @@ constexpr ckb::ThreadBlock set_submatrix(int m, int n, int k)
|
||||
return {.block_size = 256, .submatrix = {.m = m, .n = n, .k = k}};
|
||||
}
|
||||
|
||||
// Helper function to set the thread cluster dimensions.
// Builds a FwdConvAlgorithm::BlockTransfer from a single (k0, m, k1) triple:
//  - thread_cluster_dims_a takes (k0, m, k1) as given;
//  - thread_cluster_dims_b takes the same values, with m used for its n field;
//  - thread_cluster_dims_c is not parameterized and is always
//    {m_block = 1, m_wave_per_xdl = 32, n_block = 1, n_wave_per_xdl = 8}.
|
||||
constexpr FwdConvAlgorithm::BlockTransfer set_thread_cluster_dims(int k0, int m, int k1)
|
||||
{
|
||||
return {.thread_cluster_dims_a = {.k0 = k0, .m = m, .k1 = k1},
|
||||
.thread_cluster_dims_b = {.k0 = k0, .n = m, .k1 = k1},
|
||||
.thread_cluster_dims_c = {
|
||||
.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}};
|
||||
}
|
||||
|
||||
// Test cases to drive the typed test suite.
|
||||
constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
@@ -59,12 +68,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(256, 128, 64),
|
||||
.tuning_params = {.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, "
|
||||
@@ -76,12 +80,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
@@ -92,12 +91,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(128, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 8, .m = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 8, .n = 32, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(8, 32, 1),
|
||||
.pipeline_version = P::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, "
|
||||
@@ -108,12 +102,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(128, 128, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V4},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Default, 32, 32, "
|
||||
@@ -124,12 +113,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V3},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
@@ -140,12 +124,7 @@ constexpr std::array TEST_CASES = {
|
||||
.algorithm =
|
||||
{.thread_block = set_submatrix(256, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V5},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, "
|
||||
@@ -154,17 +133,10 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance5",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(256, 128, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V1,
|
||||
},
|
||||
{.thread_block = set_submatrix(256, 128, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V1},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 32, Default, 32, 32, "
|
||||
"2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>",
|
||||
@@ -172,17 +144,10 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance7",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(128, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V1,
|
||||
},
|
||||
{.thread_block = set_submatrix(128, 256, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V1},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, "
|
||||
"2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>",
|
||||
@@ -190,17 +155,10 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance8",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(128, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V1,
|
||||
},
|
||||
{.thread_block = set_submatrix(128, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V1},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, "
|
||||
"2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>",
|
||||
@@ -208,17 +166,10 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(128, 64, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V3,
|
||||
},
|
||||
{.thread_block = set_submatrix(128, 64, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V3},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 64, 64, Default, 32, 32, 2, "
|
||||
"4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>",
|
||||
@@ -226,39 +177,14 @@ constexpr std::array TEST_CASES = {
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(64, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V3,
|
||||
},
|
||||
{.thread_block = set_submatrix(64, 128, 64),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = set_thread_cluster_dims(4, 64, 1),
|
||||
.pipeline_version = P::V3},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 128, 64, Default, 32, 32, 2, "
|
||||
"4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>",
|
||||
},
|
||||
TestCase{
|
||||
.name = "GroupedConvFwdXdlBf16CompInstance9",
|
||||
.algorithm =
|
||||
{
|
||||
.thread_block = set_submatrix(64, 64, 32),
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 2, .n_xdl_per_wave = 4},
|
||||
.block_transfer = {.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8}},
|
||||
.pipeline_version = P::V3,
|
||||
},
|
||||
.expected_type =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 64, 64, 32, Default, 32, 32, 2, "
|
||||
"4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>",
|
||||
},
|
||||
};
|
||||
|
||||
static constexpr int NUM_TEST_CASES = std::tuple_size_v<decltype(TEST_CASES)>;
|
||||
@@ -316,19 +242,19 @@ TYPED_TEST(ConvBuilderInstancesTest, KernelParamsConfigured)
|
||||
const auto& tp = ALGORITHM.tuning_params;
|
||||
EXPECT_EQ(Builder::factory::TUNING.ak1, tp.ak1);
|
||||
EXPECT_EQ(Builder::factory::TUNING.bk1, tp.bk1);
|
||||
const auto& tcla = ALGORITHM.block_transfer.thread_cluster_lengths_a;
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], tcla.k0);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], tcla.m);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], tcla.k1);
|
||||
const auto& tclb = ALGORITHM.block_transfer.thread_cluster_lengths_b;
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], tclb.k0);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], tclb.n);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], tclb.k1);
|
||||
const auto& tclc = ALGORITHM.block_transfer.thread_cluster_lengths_c;
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], tclc.m_block);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], tclc.m_wave_per_xdl);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], tclc.n_block);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], tclc.n_wave_per_xdl);
|
||||
const auto& tcda = ALGORITHM.block_transfer.thread_cluster_dims_a;
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_dims[0], tcda.k0);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_dims[1], tcda.m);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_dims[2], tcda.k1);
|
||||
const auto& tcdb = ALGORITHM.block_transfer.thread_cluster_dims_b;
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_dims[0], tcdb.k0);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_dims[1], tcdb.n);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_dims[2], tcdb.k1);
|
||||
const auto& tcdc = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_dims[0], tcdc.m_block);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_dims[1], tcdc.m_wave_per_xdl);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_dims[2], tcdc.n_block);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_dims[3], tcdc.n_wave_per_xdl);
|
||||
}
|
||||
|
||||
TEST(ConvBuilderInstancesTest, TypeStringsAreUnique)
|
||||
|
||||
Reference in New Issue
Block a user