mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)
* Make fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property to render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses: skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. 
* Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatching to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to preserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warn v3 kernel users * Fix wrong codegen logic * Remove unnecessary param * Fix format * Add logits soft-capping support for fmha fwd v3 pipeline (WIP) * Add missing Kargs base type --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
|
||||
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdLogitsSoftCapKargs
|
||||
{
|
||||
FmhaFwdLogitsSoftCapKargs() = default;
|
||||
|
||||
void init_logits_soft_cap(float logits_soft_cap_)
|
||||
{
|
||||
if(0 < logits_soft_cap_)
|
||||
{
|
||||
logits_soft_cap = logits_soft_cap_;
|
||||
logits_soft_cap_rcp = 1.f / logits_soft_cap;
|
||||
}
|
||||
else
|
||||
{
|
||||
logits_soft_cap = 0.f;
|
||||
logits_soft_cap_rcp = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
float logits_soft_cap;
|
||||
float logits_soft_cap_rcp;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -112,7 +138,8 @@ struct FmhaFwdV3Kernel
|
||||
struct FmhaFwdGroupModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -127,6 +154,13 @@ struct FmhaFwdV3Kernel
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
|
||||
struct BlockIndices
|
||||
{
|
||||
ck_tile::index_t batch_idx;
|
||||
ck_tile::index_t qo_head_idx;
|
||||
ck_tile::index_t kv_head_idx;
|
||||
};
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
@@ -141,6 +175,7 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
@@ -183,6 +218,7 @@ struct FmhaFwdV3Kernel
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for logits_soft_cap
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -201,6 +237,10 @@ struct FmhaFwdV3Kernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
@@ -223,6 +263,7 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
@@ -260,6 +301,7 @@ struct FmhaFwdV3Kernel
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for logits_soft_cap
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
@@ -277,6 +319,10 @@ struct FmhaFwdV3Kernel
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
@@ -594,6 +640,21 @@ struct FmhaFwdV3Kernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
AttentionVariant variant;
|
||||
const auto variant_params = [&] {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
|
||||
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
|
||||
}
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
@@ -601,6 +662,9 @@ struct FmhaFwdV3Kernel
|
||||
lse_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr);
|
||||
}();
|
||||
|
||||
|
||||
@@ -264,6 +264,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
|
||||
static_assert(is_generic_attention_mask_v<FmhaMask>);
|
||||
|
||||
@@ -298,8 +299,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
|
||||
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
|
||||
!kStoreLSE && !kHasDropout &&
|
||||
static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
|
||||
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
|
||||
!kSkipMinSeqlenQ),
|
||||
"enable unsupported features");
|
||||
@@ -401,7 +401,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
typename OAccElementFunction,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -415,6 +417,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -721,6 +726,22 @@ struct BlockFmhaFwdV3Pipeline
|
||||
/// TODO: remove the sp_delta and use sp_compute directly
|
||||
statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
|
||||
|
||||
auto fmha_logits_trans = [&](auto sp_reg_idx) {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
|
||||
auto& logits) {
|
||||
logits = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, logits),
|
||||
block_indices.batch_idx,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
|
||||
tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute);
|
||||
}
|
||||
};
|
||||
|
||||
auto fmha_alu0 = [&](auto sp_reg_idx) {
|
||||
m_old = m; // m{j-1}
|
||||
static_assert(m.thread_buf_.size() == 1,
|
||||
@@ -746,9 +767,17 @@ struct BlockFmhaFwdV3Pipeline
|
||||
std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
sp_delta(sp_reg_idx)(i_j_idx) =
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
|
||||
}
|
||||
});
|
||||
});
|
||||
/// TODO: move some fmha_alu1() code here if necessary
|
||||
@@ -793,8 +822,16 @@ struct BlockFmhaFwdV3Pipeline
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
|
||||
|
||||
const auto tmp = [&] {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return ck_tile::exp2(m_old[i_idx] - m[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
|
||||
}
|
||||
}();
|
||||
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
|
||||
});
|
||||
|
||||
@@ -880,7 +917,16 @@ struct BlockFmhaFwdV3Pipeline
|
||||
};
|
||||
|
||||
auto fmha_alu_D_upd = [&] {
|
||||
o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
|
||||
o_acc_scale = [&] {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
|
||||
}
|
||||
}();
|
||||
|
||||
fp32x2_t pk_o_acc_scale;
|
||||
pk_o_acc_scale.x = o_acc_scale;
|
||||
@@ -928,7 +974,12 @@ struct BlockFmhaFwdV3Pipeline
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = kv_token_start + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -992,6 +1043,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p01_reg_idx, gemm0);
|
||||
fmha_alu1(xdl_SP_p23_reg_idx);
|
||||
fmha_logits_trans(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<0>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -1066,6 +1118,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p01_reg_idx, gemm0);
|
||||
fmha_alu1(xdl_SP_p23_reg_idx);
|
||||
fmha_logits_trans(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -1149,7 +1202,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
// (3) mfma (Q*K0) + softmax
|
||||
gemm(number<0>{}, /*gemm_idx=*/number<0>{});
|
||||
|
||||
fmha_logits_trans(number<0>{});
|
||||
fmha_mask(number<0>{});
|
||||
/// TODO: find better way to map fmha_alu(0,96) call
|
||||
fmha_alu0(number<0>{});
|
||||
@@ -1244,13 +1297,18 @@ struct BlockFmhaFwdV3Pipeline
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
@@ -1268,6 +1326,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user