mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)
* Re-mapping thread block indices for causal=True kernels * Use more intuitive remap_opt value * Fallback to origin remapping if seqlen_q >= 64K * Use GenericAttentionMask to reduce mask computation * Avoid unnecessary boundary check for IsMasking=false case * Fix wrong kernel entry specifier * Add s_nop to prevent delay wave0-3 * Refine scheduling * Remove unnecessary sched_group_barrier() * Move sched_group_barrier() call to scheduler * Replace inline asm s_setprio with intrinsics * Rephrase comments * Expend some o_acc rescaling insts to avoid SIMD idle * Fix block idx special mapping logic * Tune block index mapping for causal=False cases * Tune block index mapping for causal=True cases * Fix wrong vmcnt() * Remove parameter name * Use boolean option for turn on/off causal mask * Update benchmark_fwd_v3.sh option usages * Add option if compiler support it
This commit is contained in:
@@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
ck_tile::index_t remap_opt;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
@@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
|
||||
{
|
||||
if(remap_option < 1)
|
||||
{
|
||||
return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
|
||||
}
|
||||
|
||||
int32_t remapped_tg_idx = tg_idx;
|
||||
int32_t remapped_tg_idy = tg_idy;
|
||||
|
||||
if(remap_option == 2)
|
||||
{ // special remapping
|
||||
int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
|
||||
int32_t tmp1 = tmp0 & 0x7;
|
||||
|
||||
remapped_tg_idx = tmp0 >> 3;
|
||||
remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
|
||||
}
|
||||
else
|
||||
{ // normal remapping
|
||||
int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
|
||||
int32_t tgs_cu_id = remapped_tg_idx >> 3;
|
||||
|
||||
if(tgs_cu_id < cus_per_xdim_per_xcc)
|
||||
{
|
||||
int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
|
||||
int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
|
||||
|
||||
remapped_tg_idx = new_tg_idx;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(remapped_tg_idx, remapped_tg_idy);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// FmhaPipeline::kN1);
|
||||
|
||||
// assume that num_tile_n1 is always 1
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user