mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
To prevent compiler issue, remove the elementwise function we have not used.
This commit is contained in:
@@ -385,12 +385,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.x,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
descale_q * descale_k,
|
||||
descale_v};
|
||||
}();
|
||||
|
||||
@@ -60,24 +60,24 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
|
||||
struct FmhaDefaultElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
// using QElementFunction = ck_tile::identity;
|
||||
// using KElementFunction = ck_tile::identity;
|
||||
// using VElementFunction = ck_tile::identity;
|
||||
// using BiasElementFunction = ck_tile::identity;
|
||||
// using LSEElementFunction = ck_tile::identity;
|
||||
// using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::identity;
|
||||
using OAccElementFunction = ck_tile::identity;
|
||||
};
|
||||
|
||||
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
// using QElementFunction = ck_tile::identity;
|
||||
// using KElementFunction = ck_tile::identity;
|
||||
// using VElementFunction = ck_tile::identity;
|
||||
// using BiasElementFunction = ck_tile::identity;
|
||||
// using LSEElementFunction = ck_tile::identity;
|
||||
// using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::scale;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
|
||||
};
|
||||
@@ -316,12 +316,12 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t mask_y;
|
||||
ck_tile::index_t mask_x;
|
||||
typename ElementFunctions::QElementFunction q_element_func;
|
||||
typename ElementFunctions::KElementFunction k_element_func;
|
||||
typename ElementFunctions::VElementFunction v_element_func;
|
||||
typename ElementFunctions::BiasElementFunction bias_element_func;
|
||||
typename ElementFunctions::LSEElementFunction lse_element_func;
|
||||
typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
// typename ElementFunctions::QElementFunction q_element_func;
|
||||
// typename ElementFunctions::KElementFunction k_element_func;
|
||||
// typename ElementFunctions::VElementFunction v_element_func;
|
||||
// typename ElementFunctions::BiasElementFunction bias_element_func;
|
||||
// typename ElementFunctions::LSEElementFunction lse_element_func;
|
||||
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
|
||||
typename ElementFunctions::OAccElementFunction o_acc_element_func;
|
||||
float descale_qk;
|
||||
@@ -362,12 +362,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
args.nhead_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
// args.q_element_func,
|
||||
// args.k_element_func,
|
||||
// args.v_element_func,
|
||||
// args.bias_element_func,
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
@@ -406,12 +406,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
args.batch_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
// args.q_element_func,
|
||||
// args.k_element_func,
|
||||
// args.v_element_func,
|
||||
// args.bias_element_func,
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
@@ -479,8 +479,5 @@ struct fmha_fwd_traits
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
template<typename FmhaFwdArgs_>
|
||||
float fmha_fwd(fmha_fwd_traits,
|
||||
FmhaFwdArgs_,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
template <typename FmhaFwdArgs_>
|
||||
float fmha_fwd(fmha_fwd_traits, FmhaFwdArgs_, const ck_tile::stream_config&);
|
||||
|
||||
@@ -90,12 +90,12 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
|
||||
typename {F_element_func}::QElementFunction,
|
||||
typename {F_element_func}::KElementFunction,
|
||||
typename {F_element_func}::VElementFunction,
|
||||
typename {F_element_func}::BiasElementFunction,
|
||||
typename {F_element_func}::LSEElementFunction,
|
||||
typename {F_element_func}::SAccElementFunction,
|
||||
// typename {F_element_func}::QElementFunction,
|
||||
// typename {F_element_func}::KElementFunction,
|
||||
// typename {F_element_func}::VElementFunction,
|
||||
// typename {F_element_func}::BiasElementFunction,
|
||||
// typename {F_element_func}::LSEElementFunction,
|
||||
// typename {F_element_func}::SAccElementFunction,
|
||||
typename {F_element_func}::PComputeElementFunction,
|
||||
typename {F_element_func}::OAccElementFunction>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user