[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)

* Make fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to preserve the dtype orders

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment warning v3 kernel users

* Fix wrong codegen logic

* Remove unnecessary param

* Fix format

* Add logits soft-capping support for fmha fwd v3 pipeline (WIP)

* Add missing Kargs base type

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Po Yen Chen
2025-12-18 16:08:45 +08:00
committed by GitHub
parent bb8445dca8
commit bfac64953f
4 changed files with 154 additions and 28 deletions

View File

@@ -264,6 +264,7 @@ struct BlockFmhaFwdV3Pipeline
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = ck_tile::remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
static_assert(is_generic_attention_mask_v<FmhaMask>);
@@ -298,8 +299,7 @@ struct BlockFmhaFwdV3Pipeline
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kStoreLSE && !kHasDropout &&
static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
!kSkipMinSeqlenQ),
"enable unsupported features");
@@ -401,7 +401,9 @@ struct BlockFmhaFwdV3Pipeline
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -415,6 +417,9 @@ struct BlockFmhaFwdV3Pipeline
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
using namespace ck_tile;
@@ -721,6 +726,22 @@ struct BlockFmhaFwdV3Pipeline
/// TODO: remove the sp_delta and use sp_compute directly
statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
auto fmha_logits_trans = [&](auto sp_reg_idx) {
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
auto& logits) {
logits = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, logits),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute);
}
};
auto fmha_alu0 = [&](auto sp_reg_idx) {
m_old = m; // m{j-1}
static_assert(m.thread_buf_.size() == 1,
@@ -746,9 +767,17 @@ struct BlockFmhaFwdV3Pipeline
std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(kHasLogitsSoftCap)
{
sp_delta(sp_reg_idx)(i_j_idx) =
sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
}
else
{
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
}
});
});
/// TODO: move some fmha_alu1() code here if necessary
@@ -793,8 +822,16 @@ struct BlockFmhaFwdV3Pipeline
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
const auto tmp = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::exp2(m_old[i_idx] - m[i_idx]);
}
else
{
return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
}
}();
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
});
@@ -880,7 +917,16 @@ struct BlockFmhaFwdV3Pipeline
};
auto fmha_alu_D_upd = [&] {
o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
o_acc_scale = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
}
else
{
return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
}
}();
fp32x2_t pk_o_acc_scale;
pk_o_acc_scale.x = o_acc_scale;
@@ -928,7 +974,12 @@ struct BlockFmhaFwdV3Pipeline
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = kv_token_start + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
@@ -992,6 +1043,7 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
fmha_logits_trans(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<0>{});
__builtin_amdgcn_sched_barrier(0);
@@ -1066,6 +1118,7 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
fmha_logits_trans(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<1>{});
__builtin_amdgcn_sched_barrier(0);
@@ -1149,7 +1202,7 @@ struct BlockFmhaFwdV3Pipeline
// (3) mfma (Q*K0) + softmax
gemm(number<0>{}, /*gemm_idx=*/number<0>{});
fmha_logits_trans(number<0>{});
fmha_mask(number<0>{});
/// TODO: find better way to map fmha_alu(0,96) call
fmha_alu0(number<0>{});
@@ -1244,13 +1297,18 @@ struct BlockFmhaFwdV3Pipeline
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
using namespace ck_tile;
@@ -1268,6 +1326,9 @@ struct BlockFmhaFwdV3Pipeline
identity{},
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}
};