[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)

* Make fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out of get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Allow FmhaFwdApiPool to support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property to render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic (see the sketch after this list)

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatching to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to preserve the dtype order

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product (aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warn v3 kernel users

* Fix wrong codegen logic

* Remove unnecessary param

* Fix format

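A minimal sketch, not the committed implementation, of the v2/v3 selection described by the dispatch and feature-check items above. Every name below is illustrative; the generated fmha_fwd() API performs the equivalent checks, mirroring the constraints the v3 pipeline static_asserts in the diff below (hdim_q == hdim_v == 128, generic attention mask, no bias, no stored LSE, no dropout, no quant scale, no logits soft cap):

// Illustrative only: mirrors the v3 pipeline's static_asserts shown in the diff below.
// The v3 path is eligible only when every unsupported feature is disabled.
bool can_use_fmha_fwd_v3(int hdim_q, int hdim_v, bool has_bias, bool store_lse,
                         bool has_dropout, bool has_quant_scale, bool has_logits_soft_cap)
{
    return hdim_q == 128 && hdim_v == 128 && !has_bias && !store_lse &&
           !has_dropout && !has_quant_scale && !has_logits_soft_cap;
}

// Hypothetical dispatch shape inside the single fmha_fwd() entry point
// (the committed logic is emitted by the codegen, not hand-written):
//     if(can_use_fmha_fwd_v3(/* ... */)) return fmha_fwd_v3(traits, args, config);
//     return fmha_fwd_v2(traits, args, config); // fall back to the existing FAv2 kernels
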
---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Author: Po Yen Chen
Date: 2025-12-05 10:31:12 +08:00
Committed by: GitHub
Parent: d1193e8637
Commit: 05292b3604
22 changed files with 890 additions and 1449 deletions


@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
@@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
}
} // namespace detail
/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and
/// instruction scheduling optimizations.
template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
struct BlockFmhaFwdV3Pipeline
{
@@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
static_assert(is_generic_attention_mask_v<FmhaMask>);
static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
"we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
@@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline
static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
!kStoreLSE && !kHasDropout &&
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
!kSkipMinSeqlenQ),
"enable unsupported features");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this

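The static_assert(is_generic_attention_mask_v<FmhaMask>) added above relies on the new is_generic_attention_mask<> detection trait. A minimal sketch of what such a trait can look like, assuming (not shown in this diff) that GenericAttentionMask is a class template over two bool parameters:

#include <type_traits>

// Assumed shape of the existing mask type; its real template parameters may differ.
template <bool IsMasking, bool IsLocal>
struct GenericAttentionMask;

// Primary template: an arbitrary mask type is not a generic attention mask.
template <typename T>
struct is_generic_attention_mask : std::false_type {};

// Specialization matching any GenericAttentionMask instantiation.
template <bool IsMasking, bool IsLocal>
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type {};

// Convenience variable template, as used by the static_assert above.
template <typename T>
inline constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<T>::value;
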

@@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum
QRKSVS_ASYNC,
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
};
template <BlockFmhaPipelineEnum>

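The new QRKSVS_ASYNC_TRLOAD_V3 value is the tag the unified kernel template selection keys on. A minimal sketch of how the tag can be mapped to the v3 pipeline; the FmhaPipelineFromEnum trait below is hypothetical and not part of this diff:

// Hypothetical mapping trait: selects a concrete pipeline type from its enum tag.
template <BlockFmhaPipelineEnum, typename Problem>
struct FmhaPipelineFromEnum;

// Wire the new tag to BlockFmhaFwdV3Pipeline (shown in the first file of this diff).
template <typename Problem>
struct FmhaPipelineFromEnum<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, Problem>
{
    using type = BlockFmhaFwdV3Pipeline<Problem>;
};
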

@@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename LSEDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaFwdV3PipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile


@@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdV3Traits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile