mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)
* Re-mapping thread block indices for causal=True kernels * Use more intuitive remap_opt value * Fallback to origin remapping if seqlen_q >= 64K * Use GenericAttentionMask to reduce mask computation * Avoid unnecessary boundary check for IsMasking=false case * Fix wrong kernel entry specifier * Add s_nop to prevent delay wave0-3 * Refine scheduling * Remove unnecessary sched_group_barrier() * Move sched_group_barrier() call to scheduler * Replace inline asm s_setprio with intrinsics * Rephrase comments * Expend some o_acc rescaling insts to avoid SIMD idle * Fix block idx special mapping logic * Tune block index mapping for causal=False cases * Tune block index mapping for causal=True cases * Fix wrong vmcnt() * Remove parameter name * Use boolean option for turn on/off causal mask * Update benchmark_fwd_v3.sh option usages * Add option if compiler support it
This commit is contained in:
@@ -57,7 +57,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -68,11 +72,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -81,7 +93,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -115,7 +131,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -126,11 +146,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -139,7 +167,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
|
||||
{
|
||||
float result;
|
||||
asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
|
||||
{
|
||||
fp16x2_t result;
|
||||
@@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
statically_indexed_array<sp_compute_type, 2> sp;
|
||||
|
||||
decltype(gemm_1.MakeCBlockTile()) o_acc;
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd()
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
|
||||
// instructions should we move to fmha_alu1()
|
||||
static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
|
||||
|
||||
@@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
// K_mem_su_ld_insts = 1 for 32 x 128
|
||||
// V_mem_su_ld_insts = 1 for 128 x 32
|
||||
static constexpr int K_mem_su_ld_insts = 1;
|
||||
static constexpr int V_mem_su_ld_insts = 1;
|
||||
constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
@@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
@@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline
|
||||
#else
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
#endif
|
||||
// update partial o_acc [0, 2)
|
||||
static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
|
||||
// l{j}
|
||||
/// Note: The compiler keeps moving the following instructions elsewhere because 'l'
|
||||
/// is first consumed later. To anchor them here, we rewrite the final addition in
|
||||
/// inline assembly to create a dependency, forcing the dependent instructions to
|
||||
/// be emitted at this point.
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline
|
||||
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
|
||||
});
|
||||
|
||||
// update partial o_acc [2, fmha_alu_D_reg_cnt)
|
||||
static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
// update partial o_acc [0, fmha_alu_D_reg_cnt)
|
||||
static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
|
||||
o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
|
||||
});
|
||||
|
||||
/// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite
|
||||
/// the cast_tile() call into inline asm to force the conversion instructions to be
|
||||
/// generated here. The fmha_alu1() call should be placed at the end of a phase.
|
||||
/// Note: The compiler keeps sinking the conversion instructions because the
|
||||
/// result 'p' is only consumed later. To anchor them here, we rewrite
|
||||
/// the cast_tile() call as inline assembly, forcing the conversions to be
|
||||
/// emitted at this point.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
|
||||
@@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
|
||||
}
|
||||
});
|
||||
|
||||
/// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
|
||||
/// can interfere with the behavior of sched_group_barrier(), so ending the phase here
|
||||
/// avoids unintended reordering.
|
||||
};
|
||||
|
||||
auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
|
||||
@@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// phase2
|
||||
ASM_MARKER("phase2 Wave0-3");
|
||||
@@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 0");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
@@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p01_reg_idx, gemm0);
|
||||
fmha_alu1(xdl_SP_p23_reg_idx);
|
||||
|
||||
@@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
kv_token_start += kN0;
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
{
|
||||
@@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<3>{});
|
||||
@@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
auto ps_pi = number<1>{} - d;
|
||||
auto V_lds_rd_idx = ps_pi;
|
||||
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
if(1 < num_total_loop)
|
||||
{
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
}
|
||||
else
|
||||
{
|
||||
s_waitcnt_vmcnt<0>();
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
V_lds_load(V_lds_rd_idx);
|
||||
@@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
V_mem_load(number<1>{}); // V1
|
||||
K_lds_load(number<1>{}); // K1
|
||||
|
||||
asm volatile("s_setprio 0");
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<0>{}))
|
||||
;
|
||||
}
|
||||
if(warp_group_id != 0)
|
||||
{
|
||||
asm volatile("s_setprio 1");
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<1>{}))
|
||||
;
|
||||
@@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user