Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-01 20:21:23 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits
* fix bug
* Re-format files
* Allow specifying logits_soft_cap through APIs
* Support turning logits_soft_cap on/off in the async pipeline
* Do not generate non-verified kernels
* Align receipt used in Aiter
* Sync logits soft-capping across pipelines
* Re-enable some hdim pipelines
* fix perf
* Add attention variant for logits_soft_cap
* Add newline at end-of-file
* Fix performance
* Add comment to explain logits_soft_cap pre-processing
* Unify code
* Unify floating-point literal style
* Use class data member to silence the compilation error
* [CK_TILE] Update attention customization interface: add LogitsMask() (#2133)
* Send 'mask' along with variant params to the LogitsMask()
* Send block indices to the variant
* Add indices parameters in variant interface
* Fix fmha bwd codegen error
* Allow switching the logits_soft_cap impl
* Eliminate register spills
* Fix compilation errors
* Fix wrong LSE
* Fix LSE for splitkv kernel
* Sync splitkv pipeline changes
* Add batch_prefill kernel/pipeline
* Fix codegen error
* Undo changes in CMakeLists.txt
* Merge pipeline filtering check
* Use a different code path if kHasLogitsSoftCap=false
* Remove [[maybe_unused]] attribute
* Use pre-existing compile-time flag to instantiate templates
* Sync pipeline changes
* Update CHANGELOG.md
---------
Co-authored-by: Bernard <bernaliu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
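For background (not part of the patch itself): logits soft-capping squashes each raw attention score into the open interval (-cap, +cap) with a tanh before softmax, bounding extreme logits without hard clipping. A minimal scalar sketch of the transform this PR threads through the forward pipelines; the function name is illustrative, not the kernel's API:

#include <cmath>

// Illustrative only: soft-cap one raw attention logit into (-cap, +cap).
// The kernel precomputes rcp = 1/cap once (see init_logits_soft_cap in the
// diff below) so the hot loop multiplies instead of dividing.
float soft_cap_logit(float logit, float cap, float rcp)
{
    return 0.f < cap ? cap * std::tanh(logit * rcp) : logit; // cap <= 0 means disabled
}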
include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp (new file, 1134 lines)
File diff suppressed because it is too large
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/variants.hpp"
 
 #include <string>
 #include <type_traits>
@@ -47,11 +48,13 @@ struct FmhaFwdKernel
     static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
+    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
     static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
-    using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
+    using FmhaMask         = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
 
     static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
@@ -94,7 +97,7 @@ struct FmhaFwdKernel
     "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
     (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
     "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
-    (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
+    (kHasLogitsSoftCap ? "_logits" : "_nlogits") + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
     (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse") + (kHasDropout ? "_dropout" : "_ndropout") + (kDoFp8StaticQuant ? "_squant" : "_nsquant");
 #undef _SS_
 #undef _TS_
@@ -139,6 +142,28 @@ struct FmhaFwdKernel
         ck_tile::index_t nhead_stride_o;
     };
 
+    struct FmhaFwdLogitsSoftCapKargs
+    {
+        FmhaFwdLogitsSoftCapKargs() = default;
+
+        void init_logits_soft_cap(float logits_soft_cap_)
+        {
+            if(0 < logits_soft_cap_)
+            {
+                logits_soft_cap     = logits_soft_cap_;
+                logits_soft_cap_rcp = 1.f / logits_soft_cap;
+            }
+            else
+            {
+                logits_soft_cap     = 0.f;
+                logits_soft_cap_rcp = 0.f;
+            }
+        }
+
+        float logits_soft_cap;
+        float logits_soft_cap_rcp;
+    };
+
     struct FmhaFwdCommonBiasKargs
     {
         const void* bias_ptr = nullptr;
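A note on the kargs struct just added: it stores both the cap and its reciprocal so the pipeline can evaluate logits * logits_soft_cap_rcp without a per-element division, and a non-positive cap zeroes both fields, which the pipelines treat as "soft-capping off". Illustration of the on/off semantics (the struct is nested inside the kernel in the real code):

FmhaFwdLogitsSoftCapKargs softcap_kargs;
softcap_kargs.init_logits_soft_cap(50.f); // enabled: logits_soft_cap = 50.f, logits_soft_cap_rcp = 0.02f
softcap_kargs.init_logits_soft_cap(-1.f); // disabled: both fields become 0.f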
@@ -242,7 +267,8 @@ struct FmhaFwdKernel
         std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
         std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
         std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
-        std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
+        std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
+        std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
@@ -260,7 +286,8 @@ struct FmhaFwdKernel
         std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
         std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
         std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
-        std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
+        std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
+        std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
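The two kargs declarations above extend an existing compile-time composition idiom: the kargs type inherits from either a field-carrying struct or a numbered empty placeholder, so a disabled feature contributes no members and no size. The distinct FmhaFwdEmptyKargs<N> indices keep the base classes unique when several features are off at once. A self-contained sketch of the same pattern with illustrative names:

#include <type_traits>

struct SoftCapFields { float cap, cap_rcp; };
template <int> struct Empty {}; // numbered so disabled slots never collide as bases

template <bool HasSoftCap>
struct Kargs : std::conditional_t<HasSoftCap, SoftCapFields, Empty<5>>
{
    int stride_q; // common fields shared by both configurations
};

static_assert(sizeof(Kargs<false>) <= sizeof(Kargs<true>)); // empty base adds nothing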
@@ -269,6 +296,13 @@ struct FmhaFwdKernel
 
     using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
 
+    struct BlockIndices
+    {
+        ck_tile::index_t batch_idx;
+        ck_tile::index_t qo_head_idx;
+        ck_tile::index_t kv_head_idx;
+    };
+
     template <bool Cond = !kIsGroupMode>
     CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
     MakeKargsImpl(const void* q_ptr,
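BlockIndices packages the batch and head coordinates handed to the attention variant (these are the indices parameters the LogitsMask() interface change in the commit message refers to). For grouped-query attention the KV head is derived from the query head, as the i_nhead / kargs.nhead_ratio_qk initializer further down shows. A hedged example of that mapping:

// Illustrative GQA mapping: 8 query heads sharing 2 KV heads gives
// nhead_ratio_qk = 4, so query head 5 reads KV head 5 / 4 = 1.
constexpr int kv_head_of(int qo_head_idx, int nhead_ratio_qk)
{
    return qo_head_idx / nhead_ratio_qk;
}
static_assert(kv_head_of(5, 8 / 2) == 1);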
@@ -287,6 +321,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -343,6 +378,7 @@ struct FmhaFwdKernel
                     {}, // placeholder for lse
                     {}, // placeholder for fp8_static_quant args
                     {}, // placeholder for dropout
+                    {}, // placeholder for logits_soft_cap
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
@@ -398,6 +434,10 @@ struct FmhaFwdKernel
             kargs.batch_stride_randval = batch_stride_randval;
             kargs.is_store_randval     = s_randval;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }
 
         return kargs;
     }
@@ -421,6 +461,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -465,6 +506,7 @@ struct FmhaFwdKernel
                              scale_s,
                              scale_p,
                              scale_o,
+                             logits_soft_cap,
                              stride_q,
                              stride_k,
                              stride_v,
@@ -512,6 +554,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -556,6 +599,7 @@ struct FmhaFwdKernel
                              scale_s,
                              scale_p,
                              scale_o,
+                             logits_soft_cap,
                              stride_q,
                              stride_k,
                              stride_v,
@@ -603,6 +647,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -652,6 +697,7 @@ struct FmhaFwdKernel
                     {}, // placeholder for lse
                     {}, // placeholder for fp8_static_quant args
                     {}, // placeholder for dropout
+                    {}, // placeholder for logits_soft_cap
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -703,6 +749,10 @@ struct FmhaFwdKernel
             kargs.nhead_stride_randval = nhead_stride_randval;
             kargs.is_store_randval     = s_randval;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }
 
         return kargs;
     }
@@ -727,6 +777,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -765,6 +816,7 @@ struct FmhaFwdKernel
                              scale_s,
                              scale_p,
                              scale_o,
+                             logits_soft_cap,
                              stride_q,
                              stride_k,
                              stride_v,
@@ -806,6 +858,7 @@ struct FmhaFwdKernel
                   float scale_s,
                   float scale_p,
                   float scale_o,
+                  float logits_soft_cap,
                   ck_tile::index_t stride_q,
                   ck_tile::index_t stride_k,
                   ck_tile::index_t stride_v,
@@ -844,6 +897,7 @@ struct FmhaFwdKernel
                              scale_s,
                              scale_p,
                              scale_o,
+                             logits_soft_cap,
                              stride_q,
                              stride_k,
                              stride_v,
@@ -1307,6 +1361,21 @@ struct FmhaFwdKernel
             }
         }();
 
+        AttentionVariant variant;
+        const auto variant_params = [&] {
+            if constexpr(kHasLogitsSoftCap)
+            {
+                return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
+                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+            }
+            else
+            {
+                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
+            }
+        }();
+
+        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
+
         auto o_acc_tile = [&]() {
             if constexpr(kDoFp8StaticQuant)
             {
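The variant_params initializer above is an immediately-invoked lambda whose two if constexpr branches return different types; that is well-formed because only the taken branch is instantiated, so each kernel specialization sees a single return type, picked at compile time. A standalone sketch of the idiom with illustrative names:

#include <type_traits>

struct CapParams { float scale, cap, cap_rcp; };
struct StdParams { float scale; };

template <bool HasSoftCap>
auto make_params(float scale, float cap)
{
    return [&] {
        if constexpr(HasSoftCap) { return CapParams{scale, cap, 1.f / cap}; }
        else                     { return StdParams{scale}; }
    }();
}

static_assert(std::is_same_v<decltype(make_params<true>(1.f, 50.f)), CapParams>);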
@@ -1328,6 +1397,9 @@ struct FmhaFwdKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                variant,
+                variant_params,
+                block_indices,
                 smem_ptr,
                 dropout);
         }
@@ -1342,6 +1414,9 @@ struct FmhaFwdKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                variant,
+                variant_params,
+                block_indices,
                 smem_ptr,
                 dropout);
         }
 
@@ -6,6 +6,8 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/variants.hpp"
 
 #include <string>
 #include <type_traits>
 
@@ -43,14 +45,15 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
+    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV   = FmhaPipeline::Problem::kIsPagedKV;
     static constexpr bool kMergeNumHeadGroupsSeqLenQ =
         FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
 
-    using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
+    using FmhaMask         = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
 
     static_assert(!kMergeNumHeadGroupsSeqLenQ ||
@@ -95,7 +98,7 @@ struct FmhaFwdSplitKVKernel
     "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
     (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
     "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
-    (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
+    (kHasLogitsSoftCap ? "_logits" : "_nlogits") + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
     (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse") +
     (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv");
 #undef _SS_
@@ -150,6 +153,28 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t split_stride_o_acc;
     };
 
+    struct LogitsSoftCapKargs
+    {
+        LogitsSoftCapKargs() = default;
+
+        void init_logits_soft_cap(float logits_soft_cap_)
+        {
+            if(0 < logits_soft_cap_)
+            {
+                logits_soft_cap     = logits_soft_cap_;
+                logits_soft_cap_rcp = 1.f / logits_soft_cap;
+            }
+            else
+            {
+                logits_soft_cap     = 0.f;
+                logits_soft_cap_rcp = 0.f;
+            }
+        }
+
+        float logits_soft_cap;
+        float logits_soft_cap_rcp;
+    };
+
     struct CommonBiasKargs
     {
         const void* bias_ptr = nullptr;
@@ -207,7 +232,8 @@ struct FmhaFwdSplitKVKernel
                               EmptyKargs<0>>>,
         std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-        std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
+        std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
+        std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
     {
         const int32_t* seqlen_k_ptr;
 
@@ -229,7 +255,8 @@ struct FmhaFwdSplitKVKernel
                               EmptyKargs<0>>>,
         std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-        std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
+        std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
+        std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -243,6 +270,13 @@ struct FmhaFwdSplitKVKernel
 
     using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
 
+    struct BlockIndices
+    {
+        ck_tile::index_t batch_idx;
+        ck_tile::index_t qo_head_idx;
+        ck_tile::index_t kv_head_idx;
+    };
+
     template <bool Cond = !kIsGroupMode>
     __host__ static constexpr std::enable_if_t<Cond, Kargs>
     MakeKargs(const void* q_ptr,
@@ -268,6 +302,7 @@ struct FmhaFwdSplitKVKernel
               const void* cache_batch_idx,
               float scale_s,
               float scale_p,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -324,6 +359,7 @@ struct FmhaFwdSplitKVKernel
                     {}, // placeholder for mask
                     {}, // placeholder for fp8_static_quant args
                     {}, // placeholder for paged-block table or cache_batch_idx
+                    {}, // placeholder for logits_soft_cap
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                     batch_stride_q,
                     batch_stride_k,
@@ -363,6 +399,10 @@ struct FmhaFwdSplitKVKernel
         {
             kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }
 
         return kargs;
     }
@@ -392,6 +432,7 @@ struct FmhaFwdSplitKVKernel
               bool is_gappy,
               float scale_s,
               float scale_p,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -444,6 +485,7 @@ struct FmhaFwdSplitKVKernel
                     {}, // placeholder for mask
                     {}, // placeholder for fp8_static_quant args
                     {}, // placeholder for paged-block table
+                    {}, // placeholder for logits_soft_cap
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr),
@@ -478,6 +520,10 @@ struct FmhaFwdSplitKVKernel
             kargs.page_block_size = page_block_size;
             kargs.is_gappy        = is_gappy;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }
 
         return kargs;
     }
@@ -968,6 +1014,21 @@ struct FmhaFwdSplitKVKernel
             }
         }();
 
+        AttentionVariant variant;
+        const auto variant_params = [&] {
+            if constexpr(kHasLogitsSoftCap)
+            {
+                return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
+                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+            }
+            else
+            {
+                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
+            }
+        }();
+
+        BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
+
         auto o_acc_tile = [&, i_split_ = i_split]() {
             if constexpr(kDoFp8StaticQuant)
             {
@@ -991,6 +1052,9 @@ struct FmhaFwdSplitKVKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                variant,
+                variant_params,
+                block_indices,
                 kv_l2p_offset,
                 smem_ptr);
         }
@@ -1008,6 +1072,9 @@ struct FmhaFwdSplitKVKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                variant,
+                variant_params,
+                block_indices,
                 kv_l2p_offset,
                 smem_ptr);
         }
 