[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)

* Remap thread block indices for causal=True kernels

* Use more intuitive remap_opt value

* Fall back to original remapping if seqlen_q >= 64K

* Use GenericAttentionMask to reduce mask computation

* Avoid unnecessary boundary check for IsMasking=false case

* Fix wrong kernel entry specifier

* Add s_nop to avoid delaying wave0-3

* Refine scheduling

* Remove unnecessary sched_group_barrier()

* Move sched_group_barrier() call to scheduler

* Replace inline asm s_setprio with intrinsics (see the sketch after this list)

* Rephrase comments

* Expand some o_acc rescaling insts to avoid SIMD idle

* Fix block idx special mapping logic

* Tune block index mapping for causal=False cases

* Tune block index mapping for causal=True cases

* Fix wrong vmcnt()

* Remove parameter name

* Use boolean option to turn the causal mask on/off

* Update benchmark_fwd_v3.sh option usages

* Add option only if the compiler supports it
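For the s_setprio bullet above, a minimal sketch of the intrinsic-based form, assuming a HIP/Clang toolchain where __builtin_amdgcn_s_setprio is the builtin behind the s_setprio instruction (the surrounding function is hypothetical, not repo code):

// Sketch only: bump this wavefront's issue priority (a 2-bit value, 0..3)
// around a latency-critical section, then restore the default. Unlike opaque
// inline asm, the builtin stays visible to the instruction scheduler.
__device__ void hot_section_example()
{
    __builtin_amdgcn_s_setprio(1); // this wave now wins issue-arbitration ties
    // ... latency-critical MFMA/VALU work ...
    __builtin_amdgcn_s_setprio(0); // back to default priority
}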
Author: Po Yen Chen (committed by GitHub), 2025-09-16 11:32:38 +08:00
Commit: 7fbc9d6c97 (parent: 7d7ded62d3)
8 changed files with 250 additions and 112 deletions

@@ -57,7 +57,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
-else if constexpr(Phase == 1) {}
+else if constexpr(Phase == 1)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 2)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -68,11 +72,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
});
}
-else if constexpr(Phase == 3) {}
+else if constexpr(Phase == 3)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
}
else
{
-if constexpr(Phase == 0) {}
+if constexpr(Phase == 0)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 1)
{
static_for<0, 8, 1>{}([&](auto) {
@@ -81,7 +93,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
-else if constexpr(Phase == 2) {}
+else if constexpr(Phase == 2)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 3)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -115,7 +131,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
-else if constexpr(Phase == 1) {}
+else if constexpr(Phase == 1)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 2)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -126,11 +146,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
});
}
-else if constexpr(Phase == 3) {}
+else if constexpr(Phase == 3)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
}
else
{
-if constexpr(Phase == 0) {}
+if constexpr(Phase == 0)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 1)
{
static_for<0, 8, 1>{}([&](auto) {
@@ -139,7 +167,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
-else if constexpr(Phase == 2) {}
+else if constexpr(Phase == 2)
+{
+    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
+    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
+}
else if constexpr(Phase == 3)
{
#if !CK_TILE_DISABLE_PACKED_FP32
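These scheduler hunks are built from __builtin_amdgcn_sched_group_barrier(mask, size, sync_id), which asks the backend scheduler to place a group of size instructions matching mask at that point. A minimal sketch of the idea, with mask values as documented for the LLVM AMDGPU backend (the function itself is illustrative, not repo code):

// Sketch: request "1 MFMA, then 2 VALU, then 4 SALU" at this point in the
// schedule. Mask bits: 0x002 = VALU, 0x004 = SALU, 0x008 = MFMA/WMMA; the
// third argument is a sync ID that groups related barriers together.
__device__ void phase_schedule_hint_example()
{
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
    __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}

The change fills the previously empty phases with small VALU/SALU groups, so every phase now carries an explicit schedule shape instead of leaving those slots to the default scheduler.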
@@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
return result;
}
+CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
+{
+    float result;
+    asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
+                 : [result] "=v"(result)
+                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
+    return result;
+}
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
{
fp16x2_t result;
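The helpers in this hunk pin individual VALU instructions in place: wrapping v_add_f32/v_mul_f32 in inline asm makes them opaque to the optimizer, so they are emitted exactly where they are called. In the asm constraints, "=v" binds the output to a VGPR and "v" binds inputs to VGPRs. A hypothetical sibling helper in the same style, shown only to illustrate the pattern (not part of this change):

// Hypothetical: a fused multiply-add anchored at its call site.
CK_TILE_DEVICE float fma_impl_vvv(float a, float b, float c)
{
    float result;
    asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
                 : [result] "=v"(result)
                 : [a] "v"(a), [b] "v"(b), [c] "v"(c));
    return result;
}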
@@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline
statically_indexed_array<sp_compute_type, 2> sp;
decltype(gemm_1.MakeCBlockTile()) o_acc;
-constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd()
+constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
// instructions should we move to fmha_alu1()
static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
@@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline
// K_mem_su_ld_insts = 1 for 32 x 128
// V_mem_su_ld_insts = 1 for 128 x 32
-static constexpr int K_mem_su_ld_insts = 1;
-static constexpr int V_mem_su_ld_insts = 1;
+constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
+constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
@@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
-__builtin_amdgcn_sched_barrier(0);
/// FIXME: use the future-predicting method to move the window
move_tile_window(v_dram_window, {kK1, 0});
@@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline
#else
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#endif
-// update partial o_acc [0, 2)
-static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
-    [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
// l{j}
+/// Note: The compiler keeps moving the following instructions elsewhere because 'l'
+/// is first consumed later. To anchor them here, we rewrite the final addition in
+/// inline assembly to create a dependency, forcing the dependent instructions to
+/// be emitted at this point.
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
@@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
});
-// update partial o_acc [2, fmha_alu_D_reg_cnt)
-static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
-    [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
+// update partial o_acc [0, fmha_alu_D_reg_cnt)
+static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
+    o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
+});
-/// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite
-/// the cast_tile() call into inline asm to force the conversion instructions to be
-/// generated here. The fmha_alu1() call should be placed at the end of a phase.
+/// Note: The compiler keeps sinking the conversion instructions because the
+/// result 'p' is only consumed later. To anchor them here, we rewrite
+/// the cast_tile() call as inline assembly, forcing the conversions to be
+/// emitted at this point.
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
@@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
}
});
+/// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
+/// can interfere with the behavior of sched_group_barrier(), so ending the phase here
+/// avoids unintended reordering.
};
auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
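For readers tracing the rescaling above: o_acc_scale is the exp(m_old - m_new) factor from the standard flash-attention online-softmax recurrence. A host-side sketch of the bookkeeping these partial updates implement (names mirror the kernel; the code is illustrative, not extracted from the repo):

#include <algorithm>
#include <cmath>

// One KV-block step: rescale the running denominator l and the output
// accumulator o_acc before the new block's P @ V contribution lands.
void online_softmax_step(float& m, float& l, float* o_acc, int head_dim,
                         const float* s_row, int block_n)
{
    float m_new = m;
    for(int j = 0; j < block_n; ++j)
        m_new = std::max(m_new, s_row[j]);
    const float o_acc_scale = std::exp(m - m_new);

    float rowsum_p = 0.f;
    for(int j = 0; j < block_n; ++j)
        rowsum_p += std::exp(s_row[j] - m_new);

    l = l * o_acc_scale + rowsum_p; // cf. l(i_idx) = add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx])
    for(int i = 0; i < head_dim; ++i)
        o_acc[i] *= o_acc_scale;    // the loop the kernel splits at fmha_alu_D_reg_cnt
    m = m_new;
}

Splitting the o_acc rescale at fmha_alu_D_reg_cnt (now 6) moves some of those multiplies into fmha_alu1(), filling SIMD slots that would otherwise sit idle while the anchored l-update chain executes.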
@@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
-Scheduler::schedule(cl_p, number<1>{});
fmha_mask(xdl_SP_p01_reg_idx);
+Scheduler::schedule(cl_p, number<1>{});
__builtin_amdgcn_sched_barrier(0);
// phase2
ASM_MARKER("phase2 Wave0-3");
@@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 0");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p23_reg_idx, gemm1);
Scheduler::schedule(cl_p, number<2>{});
@@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 1");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
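On the s_nop lines: s_nop N stalls the issuing wave for N+1 cycles, and the sched_barrier(0) on each side keeps the compiler from moving anything across it, so the delay lands exactly after the s_barrier. Per the commit message, the intent is to avoid delaying wave0-3; the specific counts (0 and 1) are what this change chose for the two sites. A condensed sketch of the pattern:

// Sketch: a tiny, scheduler-opaque delay pinned between two fences.
__device__ void pinned_delay_example()
{
    __builtin_amdgcn_sched_barrier(0); // nothing may be reordered across this
    asm volatile("s_nop 1");           // stall this wave for 2 cycles
    __builtin_amdgcn_sched_barrier(0);
}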
@@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
-Scheduler::schedule(cl_p, number<2>{});
fmha_mask(xdl_SP_p01_reg_idx);
+Scheduler::schedule(cl_p, number<2>{});
kv_token_start += kN0;
if(num_total_loop <= ++i_total_loops)
{
@@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 1");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p23_reg_idx, gemm1);
Scheduler::schedule(cl_p, number<3>{});
@@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline
auto ps_pi = number<1>{} - d;
auto V_lds_rd_idx = ps_pi;
-s_waitcnt_vmcnt<K_mem_su_ld_insts>();
+if(1 < num_total_loop)
+{
+    s_waitcnt_vmcnt<K_mem_su_ld_insts>();
+}
+else
+{
+    s_waitcnt_vmcnt<0>();
+}
__builtin_amdgcn_s_barrier();
V_lds_load(V_lds_rd_idx);
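The new conditional exists because on the final iteration no further K prefetch is in flight, so the wave must drain all outstanding vector loads (vmcnt(0)) rather than leaving K_mem_su_ld_insts of them pending. A minimal sketch of what a helper like s_waitcnt_vmcnt<N> plausibly does (assumed implementation, not the repo's definition):

// Assumed sketch: block until at most N of this wave's vector-memory
// operations remain outstanding. The "n" constraint substitutes the literal
// so the assembler encodes the immediate itself.
template <int N>
__device__ void s_waitcnt_vmcnt()
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(N) : "memory");
}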
@@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline
V_mem_load(number<1>{}); // V1
K_lds_load(number<1>{}); // K1
asm volatile("s_setprio 0");
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
while(core_loop(number<0>{}))
;
}
if(warp_group_id != 0)
{
asm volatile("s_setprio 1");
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_s_barrier();
while(core_loop(number<1>{}))
;
@@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
-CK_TILE_HOST_DEVICE auto
-operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
-           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
-           LSEDramBlockWindowTmp& lse_dram_block_window_tmp,   // M0*1 tile
-           FmhaMask mask,
-           float scale_s,
-           void* smem_ptr) const
+CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
+                               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
+                               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+                               LSEDramBlockWindowTmp& lse_dram_block_window_tmp,   // M0*1 tile
+                               FmhaMask mask,
+                               float scale_s,
+                               void* smem_ptr) const
{
using namespace ck_tile;
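On the "Fix wrong kernel entry specifier" change above: this operator() touches LDS and GPU-only builtins, so compiling it for the host was never meaningful. The presumed distinction, following common ck_tile conventions (the exact macro definitions live elsewhere in the repo and are an assumption here):

// Presumed definitions (assumption, not copied from the repo):
#define CK_TILE_HOST_DEVICE __host__ __device__
#define CK_TILE_DEVICE __device__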