[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)

* Re-mapping thread block indices for causal=True kernels

* Use more intuitive remap_opt value

* Fall back to original remapping if seqlen_q >= 64K

* Use GenericAttentionMask to reduce mask computation

* Avoid unnecessary boundary check for IsMasking=false case

* Fix wrong kernel entry specifier

* Add s_nop to prevent delay wave0-3

* Refine scheduling

* Remove unnecessary sched_group_barrier()

* Move sched_group_barrier() call to scheduler

* Replace inline asm s_setprio with intrinsics

* Rephrase comments

* Expand some o_acc rescaling insts to avoid SIMD idle

* Fix block idx special mapping logic

* Tune block index mapping for causal=False cases

* Tune block index mapping for causal=True cases

* Fix wrong vmcnt()

* Remove parameter name

* Use boolean option for turning on/off causal mask

* Update benchmark_fwd_v3.sh option usages

* Add option if compiler supports it
This commit is contained in:
Po Yen Chen
2025-09-16 11:32:38 +08:00
committed by GitHub
parent 7d7ded62d3
commit 7fbc9d6c97
8 changed files with 250 additions and 112 deletions

View File

@@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
ck_tile::index_t remap_opt;
};
struct FmhaFwdCommonLSEKargs
@@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
kargs.remap_opt = remap_opt;
}
if constexpr(kStoreLSE)
{
@@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
kargs.remap_opt = remap_opt;
}
if constexpr(kStoreLSE)
{
@@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
using namespace ck_tile;
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
}
}
// Remaps a (tg_idx, tg_idy) thread-block coordinate according to remap_option,
// to control how tiles are distributed over the GPU for better scheduling:
//   remap_option < 1  : original mapping — only the x index is reversed
//                       (gridDim.x - tg_idx - 1), y is left untouched.
//   remap_option == 2 : "special" swizzle — mixes the low 3 bits of tg_idy
//                       into tg_idx and redistributes the remainder back into
//                       the low 3 bits of tg_idy.
//   otherwise         : "normal" swizzle — transposes the (xcc_id, cu_id)
//                       components of tg_idx so consecutive blocks land on
//                       different compute dies; blocks past the evenly
//                       divisible region keep their original index.
// NOTE(review): the >>3 / &0x7 arithmetic assumes 8 compute dies (XCCs) per
// device — confirm against the target GPU topology.
CK_TILE_DEVICE static constexpr auto
RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
{
if(remap_option < 1)
{
// No swizzle requested: just reverse the x index (matches the original
// causal-mask mapping that walks tiles from the end of the grid).
return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
}
int32_t remapped_tg_idx = tg_idx;
int32_t remapped_tg_idy = tg_idy;
if(remap_option == 2)
{ // special remapping
// Fold the low 3 bits of y into a combined linear index, then split it
// back so the low 3 bits of the combined index become the new y-low bits.
int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
int32_t tmp1 = tmp0 & 0x7;
remapped_tg_idx = tmp0 >> 3;
remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
}
else
{ // normal remapping
// Interpret tg_idx as (cu_id, xcc_id) and transpose to (xcc_id, cu_id),
// i.e. new_idx = xcc_id * cus_per_xdim_per_xcc + cu_id.
int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
int32_t tgs_cu_id = remapped_tg_idx >> 3;
// Blocks in the non-divisible tail (tgs_cu_id >= cus_per_xdim_per_xcc)
// are intentionally left unmapped.
if(tgs_cu_id < cus_per_xdim_per_xcc)
{
int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
remapped_tg_idx = new_tg_idx;
}
}
return make_tuple(remapped_tg_idx, remapped_tg_idy);
}
// Converts the launched (blockIdx.x, blockIdx.y, blockIdx.z) coordinate into
// the logical tile tuple (i_tile_m, i_tile_n, i_nhead, i_batch).
// Grid layout here is x = nhead, y = seqlen_q tiles, z = batch; the n1 (hdim_v)
// dimension is assumed to be a single tile, so i_tile_n is always 0.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
{
using namespace ck_tile;
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// FmhaPipeline::kN1);
// assume that num_tile_n1 is always 1
if constexpr(kHasMask)
{
const index_t i_nhead = blockIdx.x;
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
// With a causal mask the workload per tile shrinks toward the start of
// the sequence; reversing the block order (gridDim.y - 1 - i_block)
// launches the heaviest tiles first for better load balance.
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
}
else
{
const index_t i_nhead = blockIdx.x;
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
// No mask: all tiles cost the same, so use the identity mapping.
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
}
}