[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)

* Let fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to reserve the dtype orders

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warning v3 kernel users

* Fix wrong codegen logics

* Remove unnecessary param

* Fix format

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Po Yen Chen
2025-12-05 10:31:12 +08:00
committed by GitHub
parent d1193e8637
commit 05292b3604
22 changed files with 890 additions and 1449 deletions

View File

@@ -73,54 +73,6 @@ struct FmhaFwdKernel
#endif
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
// clang-format off
template <typename T1, typename T2 = T1> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
}
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs

View File

@@ -12,6 +12,8 @@
namespace ck_tile {
/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and
/// instruction scheduling optimizations.
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
{
@@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel
// Optional cumulative sequence length pointers for batch mode
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
};
struct FmhaFwdGroupModeKargs
@@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_q_ptr;
const int32_t* seqlen_k_ptr;
// Optional cumulative padded sequence starts (including PAD tokens)
// Used solely to compute memory offsets when sequences are physically padded.
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel
kargs.batch_stride_lse = batch_stride_lse;
}
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
@@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t remap_opt,
const void* seqstart_padded_q_ptr = nullptr,
const void* seqstart_padded_k_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel
{}, // placeholder for lse
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasMask)
@@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
}
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
if constexpr(kHasMask)
if constexpr(kIsGroupMode)
{
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
return dim3(nhead,
batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1));
}
else
{
return dim3(nhead_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
batch_size_);
return dim3(nhead,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
batch_size);
}
}
@@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel
// FmhaPipeline::kN1);
// assume that num_tile_n1 is always 1
if constexpr(kHasMask)
if constexpr(kIsGroupMode)
{
const index_t i_nhead = blockIdx.x;
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_batch = blockIdx.y;
const index_t i_block = blockIdx.z;
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
}
}
else
{
@@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel
const index_t i_block = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
}
}
}
@@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
? kargs.seqstart_padded_q_ptr[i_batch]
: query_start_unpadded;
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
? kargs.seqstart_padded_k_ptr[i_batch]
: key_start_unpadded;
batch_offset_q = query_start_padded * kargs.stride_q;
batch_offset_k = key_start_padded * kargs.stride_k;
batch_offset_v = key_start_padded * kargs.stride_v;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
if constexpr(kStoreLSE)
{
// LSE layout is [nhead, total_seqlen], index by unpadded start
batch_offset_lse = query_start_unpadded;
batch_offset_lse = query_start;
}
batch_offset_o = query_start_padded * kargs.stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
if(kargs.seqlen_q_ptr != nullptr)
{
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
}
else if(kargs.cu_seqlen_q_ptr != nullptr)
{
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
else
{
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
@@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
}
else
@@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel
kargs.seqlen_q =
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
}
if(kargs.cu_seqlen_kv_ptr != nullptr)
if(kargs.cu_seqlen_k_ptr != nullptr)
{
kargs.seqlen_k =
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
}
}