mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
clean the code
This commit is contained in:
@@ -141,10 +141,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
224, 256, 128,
|
||||
256, 256, 128,
|
||||
16, 16,
|
||||
32, 32,
|
||||
7, 2,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
|
||||
@@ -227,26 +227,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
}
|
||||
else if constexpr(stage.value == 1)
|
||||
{
|
||||
#if 0
|
||||
constexpr auto staged_num_ds_write_a_per_ds_read_a =
|
||||
num_ds_write_inst_a / staged_num_ds_read_inst_a;
|
||||
constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
|
||||
// A local write
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
|
||||
static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
|
||||
ignore = idswrite_inst;
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
#elif 1
|
||||
constexpr auto staged_num_mfma_per_ds_write_a =
|
||||
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
|
||||
|
||||
@@ -290,33 +270,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
}
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(stage.value == 2)
|
||||
{
|
||||
#if 0
|
||||
constexpr auto staged_num_buffer_load_a_per_ds_read_a =
|
||||
num_buffer_load_inst_a / staged_num_ds_read_inst_a;
|
||||
constexpr auto staged_num_mfma_per_buffer_load_a =
|
||||
staged_num_mfma / num_buffer_load_inst_a;
|
||||
// A global
|
||||
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
|
||||
ignore = i_inst;
|
||||
static_for<0, staged_num_buffer_load_a_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
|
||||
ignore = ibuf_inst;
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
#elif 1
|
||||
constexpr auto staged_num_mfma_per_buffer_load_a =
|
||||
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
|
||||
|
||||
@@ -360,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
}
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -625,7 +625,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
|
||||
Reference in New Issue
Block a user