[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)

* Re-mapping thread block indices for causal=True kernels

* Use more intuitive remap_opt value

* Fall back to original remapping if seqlen_q >= 64K

* Use GenericAttentionMask to reduce mask computation

* Avoid unnecessary boundary check for IsMasking=false case

* Fix wrong kernel entry specifier

* Add s_nop to prevent delay wave0-3

* Refine scheduling

* Remove unnecessary sched_group_barrier()

* Move sched_group_barrier() call to scheduler

* Replace inline asm s_setprio with intrinsics

* Rephrase comments

* Expand some o_acc rescaling insts to avoid SIMD idle

* Fix block idx special mapping logic

* Tune block index mapping for causal=False cases

* Tune block index mapping for causal=True cases

* Fix wrong vmcnt()

* Remove parameter name

* Use boolean option to turn the causal mask on/off

* Update benchmark_fwd_v3.sh option usages

* Add option if compiler supports it
This commit is contained in:
Po Yen Chen
2025-09-16 11:32:38 +08:00
committed by GitHub
parent 7d7ded62d3
commit 7fbc9d6c97
8 changed files with 250 additions and 112 deletions

View File

@@ -203,27 +203,36 @@ struct GenericAttentionMask
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
{
if constexpr(IsLocal)
if constexpr(!IsMasking)
{
// check top-right corner > x or left-bottom corner < x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
index_t x_end = min(i_tile_top + x, x_total);
bool top_right_edge = i_tile_right > (i_tile_top + x);
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
// TODO: no need to check begin
return (i_tile_left + TileWidth) > x_total;
}
else
{
// only need to check top-right corner > x
index_t i_tile_right = i_tile_left + TileWidth;
index_t x_end = min(i_tile_top + x, x_total);
if constexpr(IsLocal)
{
// check top-right corner > x or left-bottom corner < x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
index_t x_end = min(i_tile_top + x, x_total);
bool top_right_edge = i_tile_right > x_end;
return top_right_edge;
bool top_right_edge = i_tile_right > (i_tile_top + x);
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
bool is_partial_out_of_bound =
i_tile_right > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
}
else
{
// only need to check top-right corner > x
index_t i_tile_right = i_tile_left + TileWidth;
index_t x_end = min(i_tile_top + x, x_total);
bool top_right_edge = i_tile_right > x_end;
return top_right_edge;
}
}
}

View File

@@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
ck_tile::index_t remap_opt;
};
struct FmhaFwdCommonLSEKargs
@@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
kargs.remap_opt = remap_opt;
}
if constexpr(kStoreLSE)
{
@@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
kargs.remap_opt = remap_opt;
}
if constexpr(kStoreLSE)
{
@@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
using namespace ck_tile;
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
}
}
/// @brief Remap a flattened thread-block (x, y) index pair to improve
///        work placement across compute units.
///
/// @param tg_idx       original thread-group index along x
/// @param tg_idy       original thread-group index along y
/// @param remap_option selects the strategy:
///                       < 1  : original mapping — x index reversed, y unchanged
///                       == 2 : "special" remapping — interleaves the low 3 bits
///                              of y with the flattened (y, x) index
///                       else : "normal" remapping — swaps the roles of the low
///                              3 bits and the remaining bits of x
/// @return tuple of (remapped_x, remapped_y)
///
/// NOTE(review): the recurring 0x7 masks and >>3 shifts presumably assume a
/// partition of 8 (e.g. 8 XCDs on the target GPU) — confirm for the target
/// architecture before reusing elsewhere.
CK_TILE_DEVICE static constexpr auto
RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
{
// Fallback: keep the original behavior of walking x in reverse order.
if(remap_option < 1)
{
return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
}
int32_t remapped_tg_idx = tg_idx;
int32_t remapped_tg_idy = tg_idy;
if(remap_option == 2)
{ // special remapping
// Flatten the low 3 bits of y together with x, then split the result so
// the low 3 bits land back in y and the rest becomes the new x.
int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
int32_t tmp1 = tmp0 & 0x7;
remapped_tg_idx = tmp0 >> 3;
remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
}
else
{ // normal remapping
// Reinterpret x as (cu_id, xcc_id) and transpose to (xcc_id, cu_id) so
// consecutive blocks spread across partitions instead of filling one.
int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
int32_t tgs_cu_id = remapped_tg_idx >> 3;
// Tail blocks (beyond an even multiple of 8) keep their original index.
if(tgs_cu_id < cus_per_xdim_per_xcc)
{
int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
remapped_tg_idx = new_tg_idx;
}
}
return make_tuple(remapped_tg_idx, remapped_tg_idy);
}
/// @brief Translate the launch-grid block coordinates into
///        (m-tile, n-tile, head, batch) indices for this kernel.
///
/// Assumes num_tile_n1 == ceil(hdim_v / kN1) is always 1, so the n-tile
/// coordinate is a constant 0 and blockIdx.y maps one-to-one onto the
/// m-tile index.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
{
    using namespace ck_tile;
    const index_t head_idx  = blockIdx.x;
    const index_t block_idx = blockIdx.y;
    const index_t batch_idx = blockIdx.z;
    if constexpr(kHasMask)
    {
        // Masked kernels visit the m-tiles in reverse launch order.
        return ck_tile::make_tuple(gridDim.y - 1 - block_idx, 0, head_idx, batch_idx);
    }
    else
    {
        return ck_tile::make_tuple(block_idx, 0, head_idx, batch_idx);
    }
}

View File

@@ -57,7 +57,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
else if constexpr(Phase == 1) {}
else if constexpr(Phase == 1)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 2)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -68,11 +72,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
});
}
else if constexpr(Phase == 3) {}
else if constexpr(Phase == 3)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
}
else
{
if constexpr(Phase == 0) {}
if constexpr(Phase == 0)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 1)
{
static_for<0, 8, 1>{}([&](auto) {
@@ -81,7 +93,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
else if constexpr(Phase == 2) {}
else if constexpr(Phase == 2)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 3)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -115,7 +131,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
else if constexpr(Phase == 1) {}
else if constexpr(Phase == 1)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 2)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -126,11 +146,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
});
}
else if constexpr(Phase == 3) {}
else if constexpr(Phase == 3)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
}
else
{
if constexpr(Phase == 0) {}
if constexpr(Phase == 0)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 1)
{
static_for<0, 8, 1>{}([&](auto) {
@@ -139,7 +167,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
});
}
else if constexpr(Phase == 2) {}
else if constexpr(Phase == 2)
{
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
}
else if constexpr(Phase == 3)
{
#if !CK_TILE_DISABLE_PACKED_FP32
@@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
return result;
}
/// @brief Multiply two floats via an explicit v_mul_f32_e32 VALU instruction.
///
/// Written as inline asm (rather than `lhs * rhs`) so the multiply is
/// anchored at this program point — the pipeline notes elsewhere in this
/// commit explain that the compiler otherwise sinks/hoists these ops away
/// from where the scheduling barriers expect them (same pattern as
/// add_impl_vv above).
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
{
float result;
asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
: [result] "=v"(result)
: [lhs] "v"(lhs), [rhs] "v"(rhs));
return result;
}
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
{
fp16x2_t result;
@@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline
statically_indexed_array<sp_compute_type, 2> sp;
decltype(gemm_1.MakeCBlockTile()) o_acc;
constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd()
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
// instructions should we move to fmha_alu1()
static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
@@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline
// K_mem_su_ld_insts = 1 for 32 x 128
// V_mem_su_ld_insts = 1 for 128 x 32
static constexpr int K_mem_su_ld_insts = 1;
static constexpr int V_mem_su_ld_insts = 1;
constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
@@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
__builtin_amdgcn_sched_barrier(0);
/// FIXME: use the future-predicting method to move the window
move_tile_window(v_dram_window, {kK1, 0});
@@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline
#else
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#endif
// update partial o_acc [0, 2)
static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
// l{j}
/// Note: The compiler keeps moving the following instructions elsewhere because 'l'
/// is first consumed later. To anchor them here, we rewrite the final addition in
/// inline assembly to create a dependency, forcing the dependent instructions to
/// be emitted at this point.
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
@@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
});
// update partial o_acc [2, fmha_alu_D_reg_cnt)
static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
// update partial o_acc [0, fmha_alu_D_reg_cnt)
static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
});
/// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite
/// the cast_tile() call into inline asm to force the conversion instructions to be
/// generated here. The fmha_alu1() call should be placed at the end of a phase.
/// Note: The compiler keeps sinking the conversion instructions because the
/// result 'p' is only consumed later. To anchor them here, we rewrite
/// the cast_tile() call as inline assembly, forcing the conversions to be
/// emitted at this point.
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
@@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
}
});
/// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
/// can interfere with the behavior of sched_group_barrier(), so ending the phase here
/// avoids unintended reordering.
};
auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
@@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
Scheduler::schedule(cl_p, number<1>{});
fmha_mask(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<1>{});
__builtin_amdgcn_sched_barrier(0);
// phase2
ASM_MARKER("phase2 Wave0-3");
@@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 0");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p23_reg_idx, gemm1);
Scheduler::schedule(cl_p, number<2>{});
@@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 1");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
@@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
Scheduler::schedule(cl_p, number<2>{});
fmha_mask(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<2>{});
kv_token_start += kN0;
if(num_total_loop <= ++i_total_loops)
{
@@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_nop 1");
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p23_reg_idx, gemm1);
Scheduler::schedule(cl_p, number<3>{});
@@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline
auto ps_pi = number<1>{} - d;
auto V_lds_rd_idx = ps_pi;
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
if(1 < num_total_loop)
{
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
}
else
{
s_waitcnt_vmcnt<0>();
}
__builtin_amdgcn_s_barrier();
V_lds_load(V_lds_rd_idx);
@@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline
V_mem_load(number<1>{}); // V1
K_lds_load(number<1>{}); // K1
asm volatile("s_setprio 0");
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
while(core_loop(number<0>{}))
;
}
if(warp_group_id != 0)
{
asm volatile("s_setprio 1");
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_s_barrier();
while(core_loop(number<1>{}))
;
@@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
void* smem_ptr) const
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
void* smem_ptr) const
{
using namespace ck_tile;