mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
mxpf4 moe block_m 32
This commit is contained in:
@@ -226,8 +226,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
// constexpr auto num_dsread_a_mfma =
|
||||
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
|
||||
constexpr auto num_total_stages = MRepeat;
|
||||
constexpr auto num_total_stages = std::max(2, MRepeat);
|
||||
|
||||
if constexpr(num_total_stages > 2)
|
||||
{
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
@@ -305,6 +307,113 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale +
|
||||
num_buffer_load_b_scale;
|
||||
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
|
||||
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
|
||||
|
||||
constexpr auto mfma_perstage_more =
|
||||
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
|
||||
constexpr auto mfma_perstage_less =
|
||||
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
|
||||
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
|
||||
mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
|
||||
Reference in New Issue
Block a user