[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)

* Make fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to preserve the dtype order

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warn v3 kernel users

* Fix wrong codegen logic

* Remove unnecessary param

* Fix format

* Add logits soft-capping support for fmha fwd v3 pipeline (WIP)

* Add missing Kargs base type

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Po Yen Chen
2025-12-18 16:08:45 +08:00
committed by GitHub
parent bb8445dca8
commit bfac64953f
4 changed files with 154 additions and 28 deletions

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <type_traits>
#include <utility>
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_lse = 0;
};
/// Kernel arguments for logits soft-capping.
/// Stores both the cap and its reciprocal so the device code can multiply
/// instead of divide (presumably applied as cap * tanh(logit / cap) in the
/// pipeline — confirm against the variant implementation).
struct FmhaFwdLogitsSoftCapKargs
{
FmhaFwdLogitsSoftCapKargs() = default;

/// Initialize from the user-provided cap value.
/// @param logits_soft_cap_ cap value; a non-positive value disables
///        soft-capping (both fields are zeroed).
void init_logits_soft_cap(float logits_soft_cap_)
{
if(0 < logits_soft_cap_)
{
logits_soft_cap     = logits_soft_cap_;
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap     = 0.f;
logits_soft_cap_rcp = 0.f;
}
}

// Default to the "disabled" state so a karg whose init_logits_soft_cap()
// was never called carries zeros instead of indeterminate values.
float logits_soft_cap     = 0.f;
float logits_soft_cap_rcp = 0.f;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -112,7 +138,8 @@ struct FmhaFwdV3Kernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -127,6 +154,13 @@ struct FmhaFwdV3Kernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
/// Coordinates identifying which slice of the attention problem the current
/// workgroup handles; passed into the pipeline alongside the variant params.
struct BlockIndices
{
ck_tile::index_t batch_idx;   // batch index of this workgroup
ck_tile::index_t qo_head_idx; // query/output head index
ck_tile::index_t kv_head_idx; // key/value head index; computed as
                              // qo_head_idx / nhead_ratio_qk at the call site
};
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -141,6 +175,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -183,6 +218,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -201,6 +237,10 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -223,6 +263,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -260,6 +301,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
@@ -277,6 +319,10 @@ struct FmhaFwdV3Kernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -594,6 +640,21 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
@@ -601,6 +662,9 @@ struct FmhaFwdV3Kernel
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}();