generalized bpreshuffle pipeline optimization

This commit is contained in:
aska-0096
2025-04-27 11:50:30 +00:00
parent 49338edb1b
commit bc9c819aa4
2 changed files with 148 additions and 11 deletions

View File

@@ -132,6 +132,7 @@ static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 2;
// TODO: Epilogue performance issue. AtomicAdd loses 15~20% performance compared with Set.
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;

View File

@@ -187,6 +187,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__device__ static constexpr auto HotLoopScheduler()
{
#if 0
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a =
@@ -237,6 +238,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr auto ds_write_issue_point_stage2 = num_mfma_per_issue_less >= 3 ? 1 : 0;
static_for<0, num_mfma_inst, 1>{}([&](auto i) {
constexpr auto current_buffer_load_issue =
i < num_stage1_mfma
? (i / num_mfma_per_issue_more)
: (num_stage1_bufferloads + (i - num_stage1_mfma) / num_mfma_per_issue_less);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Group num_mfma_perstage num_ds_read_a_perstage
@@ -258,21 +263,152 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Hide B lds wr issue latency
if constexpr((((i < num_stage1_mfma) &&
(i % num_mfma_per_issue_more == ds_write_issue_point_stage1)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) % num_mfma_per_issue_less ==
ds_write_issue_point_stage2))) &&
(((i < num_stage1_mfma) &&
((i / num_mfma_per_issue_more) < num_ds_write_inst_a)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) / num_mfma_per_issue_less +
num_stage1_bufferloads) < num_ds_write_inst_a)))
// Hide A lds wr issue latency
if constexpr((current_buffer_load_issue >= num_buffer_load_inst_b) &&
((((i < num_stage1_mfma) &&
(i % num_mfma_per_issue_more == ds_write_issue_point_stage1)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) % num_mfma_per_issue_less ==
ds_write_issue_point_stage2))) &&
(((i < num_stage1_mfma) &&
((i / num_mfma_per_issue_more - num_buffer_load_inst_b) < num_ds_write_inst_a)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) / num_mfma_per_issue_less +
num_stage1_bufferloads - num_buffer_load_inst_b) < num_ds_write_inst_a))))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
#elif 1
// A/B split schedule (stage-based variant).
//
// Emits a fixed sequence of __builtin_amdgcn_sched_group_barrier hints for one hot-loop
// iteration. The iteration is divided into num_total_stages (= MRepeat) stages, each holding
// num_mfma_perstage MFMA slots. Stages are emitted in three groups, in this order:
//   1. buffer_load_b_stages stages:          MFMA + VMEM read (B)           + DS read (A)
//   2. buffer_load_a_stages stages:          MFMA + DS write + VMEM read (A) + DS read (A)
//   3. num_ds_read_a_prefetch_stages stages: MFMA                            + DS read (A)
// Mask values used below: 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
//
// The compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes,
// so the A LDS read count is halved in that case.
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
// Per-iteration instruction counts taken from HotLoopInstList.
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
// The A stages below gate DS writes and A buffer loads with the same interval, so the
// schedule requires the two counts to match.
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
// Issue cost assigned to one A DS read: 8 for a 16-byte read, 4 otherwise.
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// Number of DS reads grouped behind one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
// NOTE(review): evaluates to 0 if mfma_cycle <= 4; presumably never the case - confirm.
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// One stage per MRepeat step.
constexpr auto num_total_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
// Number of MFMA slots per stage that carry a DS-read group.
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
// The last two stages issue no buffer loads (see the final loop below).
constexpr auto num_ds_read_a_prefetch_stages = 2;
// All A + B buffer loads are spread over (num_total_stages - 2) stages. "more" stages take
// the ceiling of the average, "less" stages take the floor.
// NOTE(review): divides by (num_total_stages - 2); presumably requires MRepeat > 2 - confirm.
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
// Remainder of (total loads) / (num_total_stages - 2): the number of "more" stages.
constexpr auto buffer_load_stages_more =
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
(num_total_stages - 2)) *
((num_total_stages - 2));
// Number of leading stages assigned to B buffer loads: either B fits inside the "more"
// stages, or it takes all of them plus some "less" stages.
constexpr auto buffer_load_b_stages =
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
? num_buffer_load_inst_b / buffer_load_perstage_more
: (buffer_load_stages_more +
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
buffer_load_perstage_less);
// Remaining stages, excluding the two prefetch stages, handle A loads and A LDS writes.
constexpr auto buffer_load_a_stages =
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
// Slot offset (imfma modulo interval) at which a B buffer load is placed.
constexpr auto buffer_load_issue_point_b = 0;
// MFMA slots between consecutive buffer loads in a "more" / "less" stage.
// NOTE(review): both intervals are used as modulo divisors below and
// buffer_load_perstage_less is used as a divisor here; a zero value would not compile.
// Presumably the supported tile configurations guarantee non-zero values - confirm.
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
// Slot offset at which an A LDS write is placed.
constexpr auto ds_write_issue_point = 0;
// Slot offset at which an A buffer load is placed: one slot after the DS write when the
// stage has at least 3 MFMAs, otherwise the same slot.
// NOTE(review): if the interval is 1 while this is 1, (imfma % 1 == 1) never holds and no
// VMEM-read group would be emitted in the A stages - confirm this cannot occur.
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Place one VMEM read every "interval" MFMAs; the interval depends on whether this
// stage is a "more" or a "less" stage.
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// DS reads are attached to the last num_ds_read_a_mfma_perstage MFMAs of the stage.
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// (i + buffer_load_b_stages) is the stage index counted from the first B stage; it
// selects the "more" or "less" interval.
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
// Same interval test as above, but at the A buffer-load offset.
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// DS reads are attached to the last num_ds_read_a_mfma_perstage MFMAs of the stage.
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
// The stage index is not used in these stages; assigning it to ignore marks it as used.
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Only MFMA and DS-read groups are emitted in these stages.
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
#endif
}
template <bool HasMainLoop,