mxpf4 moe block_m 32

This commit is contained in:
xudoyuan
2025-10-24 16:54:26 +08:00
parent 6bbc05e1bd
commit 115ba5ece4

View File

@@ -226,8 +226,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = MRepeat;
constexpr auto num_total_stages = std::max(2, MRepeat);
if constexpr(num_total_stages > 2)
{
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
@@ -305,6 +307,113 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
}
});
});
}
else
{
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale +
num_buffer_load_b_scale;
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto mfma_perstage_more =
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
});
}
}
template <bool HasMainLoop,