mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
reduce prefetch stage in blockwisepipev4
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
namespace ck {
|
||||
|
||||
// Compute optimimal pipeline with highest resource request
|
||||
// GlobalPrefetchStages: 4
|
||||
// GlobalPrefetchStages: 3
|
||||
// LocalPreFillStages: 2
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 2
|
||||
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 4;
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 2;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t HotloopUnroll = 2;
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ScheduleGroup>
|
||||
__device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group)
|
||||
__device__ static constexpr void HotLoopScheduler()
|
||||
{
|
||||
// TODO: Take data type into consideration as pipe ver 3
|
||||
// A-B splited schedule
|
||||
@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
ignore = i;
|
||||
static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
|
||||
ignore = idsread;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_dsread_per_issue_a -
|
||||
num_dswrite_per_issue_a,
|
||||
schedule_group); // MFMA
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_issue_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
|
||||
ignore = idsread;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_dsread_per_issue_a -
|
||||
num_dswrite_per_issue_b,
|
||||
schedule_group); // MFMA
|
||||
0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
@@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0);
|
||||
|
||||
// Local prefill 2
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
@@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
// Global prefetch 3
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Global prefetch 4
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
// Local prefill 2
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
|
||||
|
||||
// Global prefetch 3
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
@@ -343,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
auto LoopFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto lds_write_buf,
|
||||
auto vmem_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
auto mfma_reg_buf) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
@@ -368,13 +358,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
a_blockwise_copy.RunWrite(
|
||||
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
|
||||
b_blockwise_copy.RunWrite(
|
||||
b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
@@ -411,11 +399,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
LoopFunc(I1, I1, I0, I0, I0, I0);
|
||||
LoopFunc(I0, I0, I1, I1, I1, I0);
|
||||
LoopFunc(I1, I1, I0, I0);
|
||||
LoopFunc(I0, I0, I1, I1);
|
||||
|
||||
i += HotloopUnroll;
|
||||
} while(i < (num_loop - PrefetchStages));
|
||||
@@ -424,9 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
auto ReadWriteCompFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto lds_write_buf,
|
||||
auto vmem_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
auto mfma_reg_buf) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
@@ -448,8 +434,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -479,13 +465,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
auto ReadCompFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
@@ -535,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
auto CompFunc = [&](auto mfma_reg_buf) {
|
||||
@@ -570,15 +553,13 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
|
||||
ReadCompFunc(I0, I0, I1, I1);
|
||||
ReadWriteCompFunc(I1, I1, I0, I0);
|
||||
ReadCompFunc(I0, I0, I1);
|
||||
CompFunc(I0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
|
||||
ReadWriteCompFunc(I0, I0, I1, I1, I1, I1);
|
||||
ReadCompFunc(I1, I1, I0, I1);
|
||||
ReadCompFunc(I1, I1, I0);
|
||||
CompFunc(I1);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user