mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Initialize instruction schedule
This commit is contained in:
@@ -26,6 +26,100 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock /
|
||||
(kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num =
|
||||
MPerBlock * NPerBlock * KPerBlock / (kBlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
|
||||
@@ -134,6 +134,8 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -158,6 +160,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -276,6 +281,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -293,6 +299,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
@@ -321,6 +322,9 @@ struct FlashAttentionFwdImpl
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
|
||||
Reference in New Issue
Block a user