[CK] Fixed MPerBlock=32 build issue for MXFP4 GEMM decode (#2512)

* added MPerBlock=32 support for MXFP4 GEMM decode

* added two instances for the M>128 scenario

* added one instance

* format

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: felix <felix.li@amd.com>
Authored by Mingtao Gu on 2025-07-18 14:35:54 +08:00, committed by GitHub
parent f0a8c18017
commit 0198257d79
2 changed files with 186 additions and 65 deletions
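
For context on the failure this fixes: the previous schedule set num_total_stages = MRepeat and spread the stage-1 buffer loads over (num_total_stages - 2) stages, so tile shapes yielding MRepeat <= 2 (MPerBlock = 32 among them) made the constexpr stage arithmetic divide by zero or go negative at compile time. Below is a minimal sketch of that failure mode and of the clamp adopted in this commit; MRepeat and num_buffer_load_stage1 are hypothetical stand-in values, not the kernel's real template math.

#include <algorithm>

// Hypothetical stand-in values; the real MRepeat comes from the gridwise
// GEMM's template arithmetic. Small tiles such as MPerBlock = 32 can make it 1.
constexpr int MRepeat                = 1;
constexpr int num_buffer_load_stage1 = 10;

// Old scheme: num_total_stages == MRepeat, and the stage-1 loads were divided
// across (num_total_stages - 2) stages -- division by zero for MRepeat == 2,
// negative stage counts (and negative static_for bounds) for MRepeat == 1.
// constexpr int per_stage = num_buffer_load_stage1 / (MRepeat - 2); // breaks

// Fixed scheme: clamp the stage count to at least 2, and route the
// num_total_stages == 2 case to a separate schedule (the new else branch)
// that never divides by (num_total_stages - 2).
constexpr int num_total_stages = std::max(2, MRepeat);
static_assert(num_total_stages - 2 >= 0, "stage-1 split stays well-defined");

int main() { return 0; }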


@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = std::max(2, MRepeat);
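// The clamp keeps the (num_total_stages - 2) divisors and loop bounds in the
// branch below non-negative; MRepeat alone can be 1 (e.g. MPerBlock = 32),
// which is what broke the build before this change.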
if constexpr(num_total_stages > 2)
{
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
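// The last two stages serve as the LDS prefetch window for the next loop
// iteration (see "Stage 2, Sync" below), which appears to be why stage 1
// spreads its buffer loads over only (num_total_stages - 2) stages.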
constexpr auto buffer_load_perstage_more =
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less =
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_stage2 =
math::integer_divide_floor((num_buffer_load_stage2), 2);
constexpr auto buffer_load_stages_more =
num_buffer_load_stage1 -
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
((num_total_stages - 2));
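// Worked example with hypothetical counts: num_buffer_load_stage1 = 10 and
// num_total_stages = 6 leave 4 stage-1 stages, so
//   buffer_load_perstage_more = ceil(10 / 4) = 3,
//   buffer_load_perstage_less = floor(10 / 4) = 2,
//   buffer_load_stages_more   = 10 - 2 * 4   = 2:
// the first 2 stages issue 3 loads each, the remaining 2 issue 2 each
// (3 + 3 + 2 + 2 = 10).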
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto buffer_load_issue_point_interval_stage2 =
num_mfma_perstage / buffer_load_perstage_stage2;
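// Each interval is the number of MFMA issue slots between consecutive VMEM
// loads within a stage; the imfma % interval == 0 checks below issue one
// load per interval.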
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
else
{
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale +
num_buffer_load_b_scale;
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto mfma_perstage_more =
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
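// Worked example with hypothetical counts: num_mfma_stage1 = 29 MFMAs spread
// over num_buffer_load_total = 8 loads gives
//   mfma_perstage_more = ceil(29 / 8) = 4,
//   mfma_perstage_less = floor(29 / 8) = 3,
//   mfma_stages_more   = 29 - 3 * 8   = 5:
// the first 5 loads are each preceded by 4 MFMAs, the last 3 by 3 MFMAs
// (5 * 4 + 3 * 3 = 29).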
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
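// The same more/less interleave repeats for the B tile and both scale
// tensors below; the (i + <previous load counts>) offsets keep the
// more-MFMA stages at the front of the combined issue sequence.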
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
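// Stage 2 interleaves the A-tile DS reads with the remaining MFMAs:
// ds_read_a_mfma_rate reads per MFMA, with the last group sized to the
// remainder so exactly num_ds_read_inst_a reads are issued.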
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
});
}
}
template <bool HasMainLoop,