mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)
* Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. * Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
|
||||
/// maximizes the kernel's performance.
|
||||
int remap_opt = 2;
|
||||
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
|
||||
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
|
||||
{
|
||||
if(65536 <= args.seqlen_q)
|
||||
{
|
||||
remap_opt = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
remap_opt = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto kargs = [&] {
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
0, // batch_stride_lse
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user