mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] fmha fwd splitkv optimization for decode (seqlen_q=1) (#1789)
* Update license year
* Add initial code to override decode problem
* Fix splitkv traits/args overriding error
* Reshape and transpose lse for decode
* Remove debug code
* Prettify example code
* Use better function name
* Add kMergeNumHeadGroupsSeqLenQ flag. Kernel users can use this switch to turn the optimization on or off for some problem sizes
* Add missing flag declarations
* Turn off kMergeNumHeadGroupsSeqLenQ by default in codegen
* Group similar statements together
* Remove assumption of seqlen_q=1
* Remove kMergeNumHeadGroupsSeqLenQ from splitkv combine kernel
* Support kMergeNumHeadGroupsSeqLenQ=true in fmha splitkv kernel
* Run kMergeNumHeadGroupsSeqLenQ=true kernels when needed
* Fix group mode block skip logic
* Undo changes of normal fwd kernel
* Update in GridSize() and using GridSize() for splitkv kernel (#1799)
---------
Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
@@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype};
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
struct kernel_runner {{
|
||||
template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
|
||||
struct instance {{
|
||||
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
|
||||
@@ -64,11 +64,12 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
false,
|
||||
/*kHasBiasGrad=*/false,
|
||||
{F_lse},
|
||||
{F_squant},
|
||||
{F_pagedkv},
|
||||
kHasUnevenSplits,
|
||||
kMergeNumHeadGroupsSeqLenQ,
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
@@ -115,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
namespace {{
|
||||
// Runs one of the two `instance` variants for the given stream config and
// split-KV arguments; `s` and `a` are forwarded unchanged to instance<...>::run.
//
// kHasUnevenSplits is passed straight through as the first template argument
// of `instance`. The variant with kMergeNumHeadGroupsSeqLenQ=true is selected
// only when BOTH of the following hold; every other case uses the default
// instance<kHasUnevenSplits>:
//   - compile time: F_hdim is 128, F_bias is NO_BIAS, and F_mask is either
//     FmhaMasks::NoMask or ck_tile::SimplifiedGenericAttentionMask<false>
//   - run time: a.max_seqlen_q == 1 and a.nhead_k < a.nhead_q
template <bool kHasUnevenSplits>
|
||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
|
||||
// Compile-time gate: when it is false, only the final else branch below is compiled.
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
// Run-time gate: single query row and fewer K heads than Q heads.
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
}} // anonymous namespace
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
// we don't check every seqlen_k values for kvcache
|
||||
if (a.seqlen_k_ptr != nullptr) {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// make sure F_bn0 is divisible by F_bk1
|
||||
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
kernel_runner<false>::run(s, a);
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}} else {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
@@ -146,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
namespace {{
|
||||
template <ck_tile::index_t kLogMaxSplits>
|
||||
struct kernel_runner {{
|
||||
struct instance {{
|
||||
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_lse},
|
||||
@@ -196,22 +219,22 @@ template<>
|
||||
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if (a.num_splits <= 8) {{
|
||||
kernel_runner<3>::run(s, a);
|
||||
instance<3>::run(s, a);
|
||||
}} else if (a.num_splits <= 16) {{
|
||||
kernel_runner<4>::run(s, a);
|
||||
instance<4>::run(s, a);
|
||||
}} else if (a.num_splits <= 32) {{
|
||||
kernel_runner<5>::run(s, a);
|
||||
instance<5>::run(s, a);
|
||||
}} else if (a.num_splits <= 64) {{
|
||||
kernel_runner<6>::run(s, a);
|
||||
instance<6>::run(s, a);
|
||||
}} else if (a.num_splits <= 128) {{
|
||||
kernel_runner<7>::run(s, a);
|
||||
instance<7>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user