mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK] Fixed MPerBlock=32 build issue for MXFP4 GEMM decode (#2512)
* added MPerBlock=32 for MXFP4 GEMM decode * added two instances for M>128 scenario. * added 1 instance * format --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
// constexpr auto num_dsread_a_mfma =
|
||||
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
|
||||
constexpr auto num_total_stages = MRepeat;
|
||||
constexpr auto num_total_stages = std::max(2, MRepeat);
|
||||
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
if constexpr(num_total_stages > 2)
|
||||
{
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale +
|
||||
num_buffer_load_b_scale;
|
||||
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
|
||||
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
|
||||
|
||||
constexpr auto mfma_perstage_more =
|
||||
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
|
||||
constexpr auto mfma_perstage_less =
|
||||
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
|
||||
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
|
||||
mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
|
||||
Reference in New Issue
Block a user