[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)

* Let fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to reserve the dtype orders

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warn v3 kernel users

* Fix wrong codegen logics

* Remove unnecessary param

* Fix format

* Add logits soft-capping support for fmha fwd v3 pipeline (WIP)

* Add missing Kargs base type

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Po Yen Chen
2025-12-18 16:08:45 +08:00
committed by GitHub
parent bb8445dca8
commit bfac64953f
4 changed files with 154 additions and 28 deletions

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <type_traits>
#include <utility>
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_lse = 0;
};
// Kernel-argument mixin carrying the logits soft-capping parameters.
// Mixed into the batch/group mode Kargs via std::conditional_t when
// kHasLogitsSoftCap is enabled; otherwise an empty placeholder base is used.
struct FmhaFwdLogitsSoftCapKargs
{
    FmhaFwdLogitsSoftCapKargs() = default;

    // Enable soft-capping when logits_soft_cap_ is strictly positive; any
    // non-positive value disables it (both fields become 0.f).
    void init_logits_soft_cap(float logits_soft_cap_)
    {
        if(0 < logits_soft_cap_)
        {
            logits_soft_cap = logits_soft_cap_;
            // cache the reciprocal so the pipeline can multiply instead of divide
            logits_soft_cap_rcp = 1.f / logits_soft_cap;
        }
        else
        {
            logits_soft_cap     = 0.f;
            logits_soft_cap_rcp = 0.f;
        }
    }

    // Default-initialize to the "disabled" state so a Kargs whose caller
    // skips init_logits_soft_cap() never exposes indeterminate values.
    float logits_soft_cap     = 0.f;
    float logits_soft_cap_rcp = 0.f;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -112,7 +138,8 @@ struct FmhaFwdV3Kernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -127,6 +154,13 @@ struct FmhaFwdV3Kernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
// Per-workgroup coordinates handed to the attention variant callbacks
// (LogitsTransform / LogitsMask). Populated from the kernel's block ids:
// {i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}.
struct BlockIndices
{
// index of the batch this block works on
ck_tile::index_t batch_idx;
// query/output head index (i_nhead)
ck_tile::index_t qo_head_idx;
// key/value head index; derived from qo_head_idx via nhead_ratio_qk (GQA/MQA mapping)
ck_tile::index_t kv_head_idx;
};
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -141,6 +175,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -183,6 +218,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -201,6 +237,10 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -223,6 +263,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -260,6 +301,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
@@ -277,6 +319,10 @@ struct FmhaFwdV3Kernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -594,6 +640,21 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
@@ -601,6 +662,9 @@ struct FmhaFwdV3Kernel
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}();

View File

@@ -264,6 +264,7 @@ struct BlockFmhaFwdV3Pipeline
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = ck_tile::remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
static_assert(is_generic_attention_mask_v<FmhaMask>);
@@ -298,8 +299,7 @@ struct BlockFmhaFwdV3Pipeline
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kStoreLSE && !kHasDropout &&
static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
!kSkipMinSeqlenQ),
"enable unsupported features");
@@ -401,7 +401,9 @@ struct BlockFmhaFwdV3Pipeline
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -415,6 +417,9 @@ struct BlockFmhaFwdV3Pipeline
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
using namespace ck_tile;
@@ -721,6 +726,22 @@ struct BlockFmhaFwdV3Pipeline
/// TODO: remove the sp_delta and use sp_compute directly
statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
auto fmha_logits_trans = [&](auto sp_reg_idx) {
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
auto& logits) {
logits = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, logits),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute);
}
};
auto fmha_alu0 = [&](auto sp_reg_idx) {
m_old = m; // m{j-1}
static_assert(m.thread_buf_.size() == 1,
@@ -746,9 +767,17 @@ struct BlockFmhaFwdV3Pipeline
std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(kHasLogitsSoftCap)
{
sp_delta(sp_reg_idx)(i_j_idx) =
sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
}
else
{
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
}
});
});
/// TODO: move some fmha_alu1() code here if necessary
@@ -793,8 +822,16 @@ struct BlockFmhaFwdV3Pipeline
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
const auto tmp = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::exp2(m_old[i_idx] - m[i_idx]);
}
else
{
return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
}
}();
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
});
@@ -880,7 +917,16 @@ struct BlockFmhaFwdV3Pipeline
};
auto fmha_alu_D_upd = [&] {
o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
o_acc_scale = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
}
else
{
return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
}
}();
fp32x2_t pk_o_acc_scale;
pk_o_acc_scale.x = o_acc_scale;
@@ -928,7 +974,12 @@ struct BlockFmhaFwdV3Pipeline
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = kv_token_start + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
@@ -992,6 +1043,7 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
fmha_logits_trans(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<0>{});
__builtin_amdgcn_sched_barrier(0);
@@ -1066,6 +1118,7 @@ struct BlockFmhaFwdV3Pipeline
__builtin_amdgcn_sched_barrier(0);
cl_calc(xdl_SP_p01_reg_idx, gemm0);
fmha_alu1(xdl_SP_p23_reg_idx);
fmha_logits_trans(xdl_SP_p01_reg_idx);
Scheduler::schedule(cl_p, number<1>{});
__builtin_amdgcn_sched_barrier(0);
@@ -1149,7 +1202,7 @@ struct BlockFmhaFwdV3Pipeline
// (3) mfma (Q*K0) + softmax
gemm(number<0>{}, /*gemm_idx=*/number<0>{});
fmha_logits_trans(number<0>{});
fmha_mask(number<0>{});
/// TODO: find better way to map fmha_alu(0,96) call
fmha_alu0(number<0>{});
@@ -1244,13 +1297,18 @@ struct BlockFmhaFwdV3Pipeline
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
using namespace ck_tile;
@@ -1268,6 +1326,9 @@ struct BlockFmhaFwdV3Pipeline
identity{},
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}
};