Initialize instruction schedule

This commit is contained in:
MHYang
2025-04-22 14:43:38 +00:00
committed by Philip Maybank
parent 879edeadf1
commit 0d8693776e
3 changed files with 107 additions and 0 deletions

View File

@@ -26,6 +26,100 @@ struct BlockGemmARegBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MPerXDL = WG::kM;
constexpr index_t NPerXDL = WG::kN;
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = config.template get<1>();
constexpr index_t B_LDS_RW_Width = SmemPack;
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock /
(kBlockSize * B_LDS_RW_Width);
constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (kBlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// B split schedule
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_b_issue_cycle =
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
constexpr auto num_mfma_per_dswrite_b =
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_b *
num_dswrite_per_issue_b,
0); // MFMA
});
// stage 2
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,

View File

@@ -134,6 +134,8 @@ struct BlockGemmPipelineAGmemBGmemCReg<
b_block_tile = load_tile(b_copy_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -158,6 +160,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}
@@ -276,6 +281,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
b_block_tile = load_tile(b_copy_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -293,6 +299,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}

View File

@@ -310,6 +310,7 @@ struct FlashAttentionFwdImpl
if constexpr(k1_loops > 1)
{
__builtin_amdgcn_sched_barrier(0);
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
@@ -321,6 +322,9 @@ struct FlashAttentionFwdImpl
block_sync_lds();
store_tile(v_lds_window, v);
move_tile_window(v_dram_window, {0, kK1PerBlock});
gemm1.template HotLoopScheduler<8, 4>();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail