diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index 0cffb2642c..6b3464d226 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -211,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
     const bool can_dispatch_v3 =
         (device_name.compare(0, 6, "gfx950") == 0) and
         (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
-        traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
-        (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
-        (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
-        (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
-        (args.hdim_v == 128);
+        traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
+        (not traits.has_lse) and (not traits.has_dropout) and
+        (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
+        (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
     if ({F_is_v3_enabled} and can_dispatch_v3) {{
         return fmha_fwd_v3(traits, args, config);
     }} else {{
@@ -1082,9 +1081,9 @@ class KernelComponentFactoryGfx950(
             # qr_async_trload_v3 only supports hdim=hdim_v=128 for now
             if (hdim, hdim_v) == (128, 128):
                 # qr_async_trload_v3 only supports (generic) causal mask
-                for mask in ["no", "causal"]:
+                for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
                     pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
-                        F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
+                        F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip

         return pipelines
diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index 60ba334fc0..6a1c620577 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -728,6 +728,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
             args.nhead_q,
             args.nhead_q / args.nhead_k,
             args.scale_s,
+            args.logits_soft_cap,
             args.stride_q,
             args.stride_k,
             args.stride_v,
@@ -758,6 +759,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
             args.nhead_q,
             args.nhead_q / args.nhead_k,
             args.scale_s,
+            args.logits_soft_cap,
             args.stride_q,
             args.stride_k,
             args.stride_v,
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
index f981c54bd8..6fe1de634d 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
+#include "ck_tile/ops/fmha/block/variants.hpp"
 #include
 #include

@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
     using ODataType    = ck_tile::remove_cvref_t;
     using SaccDataType = ck_tile::remove_cvref_t;

-    static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
-    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
-    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
-    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
+    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
+    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
+    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
+    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;

-    using FmhaMask = ck_tile::remove_cvref_t;
+    using AttentionVariant = ck_tile::remove_cvref_t;
+    using FmhaMask         = ck_tile::remove_cvref_t;
     static constexpr bool kHasMask = FmhaMask::IsMasking;

     template // to avoid duplicated base class prblem, introduce an template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
         ck_tile::index_t batch_stride_lse = 0;
     };

+    struct FmhaFwdLogitsSoftCapKargs
+    {
+        FmhaFwdLogitsSoftCapKargs() = default;
+
+        void init_logits_soft_cap(float logits_soft_cap_)
+        {
+            if(0 < logits_soft_cap_)
+            {
+                logits_soft_cap     = logits_soft_cap_;
+                logits_soft_cap_rcp = 1.f / logits_soft_cap;
+            }
+            else
+            {
+                logits_soft_cap     = 0.f;
+                logits_soft_cap_rcp = 0.f;
+            }
+        }
+
+        float logits_soft_cap;
+        float logits_soft_cap_rcp;
+    };
+
     struct FmhaFwdBatchModeKargs
         : FmhaFwdCommonKargs,
           std::conditional_t>,
-          std::conditional_t>
+          std::conditional_t>,
+          std::conditional_t>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
@@ -112,7 +138,8 @@ struct FmhaFwdV3Kernel
     struct FmhaFwdGroupModeKargs
         : FmhaFwdCommonKargs,
           std::conditional_t>,
-          std::conditional_t>
+          std::conditional_t>,
+          std::conditional_t>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -127,6 +154,13 @@ struct FmhaFwdV3Kernel
     using Kargs = std::conditional_t;

+    struct BlockIndices
+    {
+        ck_tile::index_t batch_idx;
+        ck_tile::index_t qo_head_idx;
+        ck_tile::index_t kv_head_idx;
+    };
+
     template
     CK_TILE_HOST static constexpr std::enable_if_t
     MakeKargs(const void* q_ptr,
@@ -141,6 +175,7 @@ struct FmhaFwdV3Kernel
               ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -183,6 +218,7 @@ struct FmhaFwdV3Kernel
                       nhead_stride_o}, // args for common karg
                      {},               // placeholder for mask
                      {},               // placeholder for lse
+                     {},               // placeholder for logits_soft_cap
                      batch_stride_q,
                      batch_stride_k,
                      batch_stride_v,
@@ -201,6 +237,10 @@ struct FmhaFwdV3Kernel
             kargs.nhead_stride_lse = nhead_stride_lse;
             kargs.batch_stride_lse = batch_stride_lse;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }

         kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr);
         kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr);
@@ -223,6 +263,7 @@ struct FmhaFwdV3Kernel
               ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -260,6 +301,7 @@ struct FmhaFwdV3Kernel
                       nhead_stride_o}, // args for common karg
                      {},               // placeholder for mask
                      {},               // placeholder for lse
+                     {},               // placeholder for logits_soft_cap
                      reinterpret_cast(seqstart_q_ptr),
                      reinterpret_cast(seqstart_k_ptr),
                      reinterpret_cast(seqlen_q_ptr),
@@ -277,6 +319,10 @@ struct FmhaFwdV3Kernel
             kargs.lse_ptr          = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }

         kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr);
         kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr);
@@ -594,6 +640,21 @@ struct FmhaFwdV3Kernel
             return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
         }();

+        AttentionVariant variant;
+        const auto variant_params = [&] {
+            if constexpr(kHasLogitsSoftCap)
+            {
+                return ck_tile::LogitsSoftCapParams{
+                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+            }
+            else
+            {
+                return ck_tile::StandardAttentionParams{mask, kargs.scale_s};
+            }
+        }();
+
+        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
+
         auto o_acc_tile = [&]() {
             return FmhaPipeline{}(q_dram_window,
                                   k_dram_window,
@@ -601,6 +662,9 @@ struct FmhaFwdV3Kernel
                                   lse_dram_window,
                                   mask,
                                   kargs.scale_s,
+                                  variant,
+                                  variant_params,
+                                  block_indices,
                                   smem_ptr);
         }();

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
index 68ec349694..c25f57632f 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
@@ -264,6 +264,7 @@ struct BlockFmhaFwdV3Pipeline
     using PDataType    = ck_tile::remove_cvref_t;
     using OaccDataType = ck_tile::remove_cvref_t;
     using ODataType    = ck_tile::remove_cvref_t;
+    using AttentionVariant = ck_tile::remove_cvref_t;

     using FmhaMask = ck_tile::remove_cvref_t;
     static_assert(is_generic_attention_mask_v);
@@ -298,8 +299,7 @@ struct BlockFmhaFwdV3Pipeline
     static constexpr bool kHasDropout     = Problem::kHasDropout;
     static constexpr auto QScaleEnum      = Problem::QScaleEnum;
     static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
-    static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
-                   !kStoreLSE && !kHasDropout &&
+    static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout &&
                    (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
                    !kSkipMinSeqlenQ),
                   "enable unsupported features");
@@ -401,7 +401,9 @@ struct BlockFmhaFwdV3Pipeline
               typename LSEElementFunction,
               typename SAccElementFunction,
               typename PComputeElementFunction,
-              typename OAccElementFunction>
+              typename OAccElementFunction,
+              typename AttentionVariantParams,
+              typename BlockIndices>
     CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                                    const QElementFunction& q_element_func,
                                    const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -415,6 +417,9 @@ struct BlockFmhaFwdV3Pipeline
                                    const OAccElementFunction& o_acc_element_func,
                                    FmhaMask mask,
                                    float scale_s,
+                                   const AttentionVariant& variant,
+                                   const AttentionVariantParams& variant_params,
+                                   const BlockIndices& block_indices,
                                    void* smem_ptr) const
     {
         using namespace ck_tile;
@@ -721,6 +726,22 @@ struct BlockFmhaFwdV3Pipeline
         /// TODO: remove the sp_delta and use sp_compute directly
         statically_indexed_array{}).sp_compute), 2> sp_delta;

+        auto fmha_logits_trans = [&](auto sp_reg_idx) {
+            if constexpr(kHasLogitsSoftCap)
+            {
+                auto apply_logits_transform = [&variant, &variant_params, &block_indices](
+                                                  auto& logits) {
+                    logits = variant.LogitsTransform(variant_params,
+                                                     variant.QueryTransform(variant_params, logits),
+                                                     block_indices.batch_idx,
+                                                     block_indices.qo_head_idx,
+                                                     block_indices.kv_head_idx);
+                };
+
+                tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute);
+            }
+        };
+
         auto fmha_alu0 = [&](auto sp_reg_idx) {
             m_old = m; // m{j-1}
             static_assert(m.thread_buf_.size() == 1,
@@ -746,9 +767,17 @@ struct BlockFmhaFwdV3Pipeline
                 std::decay_t::get_distributed_spans();
             sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
-                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                    sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
-                        sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                    if constexpr(kHasLogitsSoftCap)
+                    {
+                        sp_delta(sp_reg_idx)(i_j_idx) =
+                            sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx);
+                    }
+                    else
+                    {
+                        sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
+                            sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
+                    }
                 });
             });
             /// TODO: move some fmha_alu1() code here if necessary
@@ -793,8 +822,16 @@ struct BlockFmhaFwdV3Pipeline
             constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
             sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
-                const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
-
+                const auto tmp = [&] {
+                    if constexpr(kHasLogitsSoftCap)
+                    {
+                        return ck_tile::exp2(m_old[i_idx] - m[i_idx]);
+                    }
+                    else
+                    {
+                        return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
+                    }
+                }();
                 l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
             });

@@ -880,7 +917,16 @@ struct BlockFmhaFwdV3Pipeline
         };

         auto fmha_alu_D_upd = [&] {
-            o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
+            o_acc_scale = [&] {
+                if constexpr(kHasLogitsSoftCap)
+                {
+                    return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]);
+                }
+                else
+                {
+                    return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
+                }
+            }();

             fp32x2_t pk_o_acc_scale;
             pk_o_acc_scale.x = o_acc_scale;
@@ -928,7 +974,12 @@ struct BlockFmhaFwdV3Pipeline
                     const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                     const auto col = kv_token_start + tile_idx.at(number<1>{});

-                    return mask.IsOutOfBound(row, col);
+                    return !variant.LogitsMask(variant_params,
+                                               block_indices.batch_idx,
+                                               row,
+                                               col,
+                                               block_indices.qo_head_idx,
+                                               block_indices.kv_head_idx);
                 });
             }
         }
@@ -992,6 +1043,7 @@ struct BlockFmhaFwdV3Pipeline
             __builtin_amdgcn_sched_barrier(0);
             cl_calc(xdl_SP_p01_reg_idx, gemm0);
             fmha_alu1(xdl_SP_p23_reg_idx);
+            fmha_logits_trans(xdl_SP_p01_reg_idx);
             Scheduler::schedule(cl_p, number<0>{});
             __builtin_amdgcn_sched_barrier(0);

@@ -1066,6 +1118,7 @@ struct BlockFmhaFwdV3Pipeline
             __builtin_amdgcn_sched_barrier(0);
             cl_calc(xdl_SP_p01_reg_idx, gemm0);
             fmha_alu1(xdl_SP_p23_reg_idx);
+            fmha_logits_trans(xdl_SP_p01_reg_idx);
             Scheduler::schedule(cl_p, number<1>{});
             __builtin_amdgcn_sched_barrier(0);

@@ -1149,7 +1202,7 @@ struct BlockFmhaFwdV3Pipeline

             // (3) mfma (Q*K0) + softmax
             gemm(number<0>{}, /*gemm_idx=*/number<0>{});
-
+            fmha_logits_trans(number<0>{});
             fmha_mask(number<0>{});
             /// TODO: find better way to map fmha_alu(0,96) call
             fmha_alu0(number<0>{});
@@ -1244,13 +1297,18 @@ struct BlockFmhaFwdV3Pipeline
     template <typename QDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
-              typename LSEDramBlockWindowTmp>
+              typename LSEDramBlockWindowTmp,
+              typename AttentionVariantParams,
+              typename BlockIndices>
     CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                                    const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
                                    const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                                    LSEDramBlockWindowTmp& lse_dram_block_window_tmp,   // M0*1 tile
                                    FmhaMask mask,
                                    float scale_s,
+                                   const AttentionVariant& variant,
+                                   const AttentionVariantParams& variant_params,
+                                   const BlockIndices& block_indices,
                                    void* smem_ptr) const
     {
         using namespace ck_tile;
@@ -1268,6 +1326,9 @@ struct BlockFmhaFwdV3Pipeline
                        identity{},
                        mask,
                        scale_s,
+                       variant,
+                       variant_params,
+                       block_indices,
                        smem_ptr);
     }
 };
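
The `kHasLogitsSoftCap` branches above hinge on one detail: with the soft-cap variant, `scale_s` is handed to `LogitsSoftCapParams` and consumed inside `variant.QueryTransform`/`variant.LogitsTransform`, so the running-max rescale in `fmha_alu0`/`fmha_alu_D_upd` reduces to `exp2(m_old - m)` instead of `exp2(scale_s * (m_old - m))`. As a reference point, the sketch below spells out the conventional soft-cap formula, `cap * tanh(scale_s * s / cap)`, using the two fields precomputed by `init_logits_soft_cap()`; the struct and function names are illustrative placeholders, not the actual `ck_tile::LogitsSoftCapParams` / `variants.hpp` implementation, which may additionally fold a log2(e) factor into the transform to match the `exp2`-based softmax.

// Illustrative sketch only (hypothetical names); shows the conventional
// logits-soft-cap math assumed by the kHasLogitsSoftCap code paths above.
#include <cmath>

struct SoftCapSketch
{
    float scale_s;             // softmax scale for the raw Q*K^T logit
    float logits_soft_cap;     // cap value; > 0 enables the transform
    float logits_soft_cap_rcp; // precomputed 1.f / logits_soft_cap (see init_logits_soft_cap)

    // Conventional soft cap: cap * tanh((scale_s * s) / cap).
    // Since scale_s is absorbed here, the caller's exp2 rescale no longer
    // multiplies the (m_old - m) difference by scale_s.
    float transform_logit(float s) const
    {
        return logits_soft_cap * std::tanh(s * scale_s * logits_soft_cap_rcp);
    }
};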