To prevent compiler issue, remove the elementwise function we have not used.

This commit is contained in:
rocking
2024-04-08 09:42:26 +00:00
parent 68153dea0b
commit f7d81364f3
5 changed files with 104 additions and 114 deletions

View File

@@ -385,12 +385,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.x,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
descale_q * descale_k,
descale_v};
}();

View File

@@ -60,24 +60,24 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
struct FmhaDefaultElementFunctions
{
using QElementFunction = ck_tile::identity;
using KElementFunction = ck_tile::identity;
using VElementFunction = ck_tile::identity;
using BiasElementFunction = ck_tile::identity;
using LSEElementFunction = ck_tile::identity;
using SAccElementFunction = ck_tile::identity;
// using QElementFunction = ck_tile::identity;
// using KElementFunction = ck_tile::identity;
// using VElementFunction = ck_tile::identity;
// using BiasElementFunction = ck_tile::identity;
// using LSEElementFunction = ck_tile::identity;
// using SAccElementFunction = ck_tile::identity;
using PComputeElementFunction = ck_tile::identity;
using OAccElementFunction = ck_tile::identity;
};
struct FmhaF8StaticQuantizationElementFunctions
{
using QElementFunction = ck_tile::identity;
using KElementFunction = ck_tile::identity;
using VElementFunction = ck_tile::identity;
using BiasElementFunction = ck_tile::identity;
using LSEElementFunction = ck_tile::identity;
using SAccElementFunction = ck_tile::identity;
// using QElementFunction = ck_tile::identity;
// using KElementFunction = ck_tile::identity;
// using VElementFunction = ck_tile::identity;
// using BiasElementFunction = ck_tile::identity;
// using LSEElementFunction = ck_tile::identity;
// using SAccElementFunction = ck_tile::identity;
using PComputeElementFunction = ck_tile::scale;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
};
@@ -316,12 +316,12 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
typename ElementFunctions::QElementFunction q_element_func;
typename ElementFunctions::KElementFunction k_element_func;
typename ElementFunctions::VElementFunction v_element_func;
typename ElementFunctions::BiasElementFunction bias_element_func;
typename ElementFunctions::LSEElementFunction lse_element_func;
typename ElementFunctions::SAccElementFunction s_acc_element_func;
// typename ElementFunctions::QElementFunction q_element_func;
// typename ElementFunctions::KElementFunction k_element_func;
// typename ElementFunctions::VElementFunction v_element_func;
// typename ElementFunctions::BiasElementFunction bias_element_func;
// typename ElementFunctions::LSEElementFunction lse_element_func;
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
typename ElementFunctions::OAccElementFunction o_acc_element_func;
float descale_qk;
@@ -362,12 +362,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
// args.bias_element_func,
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
@@ -406,12 +406,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
// args.bias_element_func,
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
@@ -479,8 +479,5 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
template<typename FmhaFwdArgs_>
float fmha_fwd(fmha_fwd_traits,
FmhaFwdArgs_,
const ck_tile::stream_config&);
template <typename FmhaFwdArgs_>
float fmha_fwd(fmha_fwd_traits, FmhaFwdArgs_, const ck_tile::stream_config&);

View File

@@ -90,12 +90,12 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_mask_{F_idx} = {F_mask};
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
typename {F_element_func}::QElementFunction,
typename {F_element_func}::KElementFunction,
typename {F_element_func}::VElementFunction,
typename {F_element_func}::BiasElementFunction,
typename {F_element_func}::LSEElementFunction,
typename {F_element_func}::SAccElementFunction,
// typename {F_element_func}::QElementFunction,
// typename {F_element_func}::KElementFunction,
// typename {F_element_func}::VElementFunction,
// typename {F_element_func}::BiasElementFunction,
// typename {F_element_func}::LSEElementFunction,
// typename {F_element_func}::SAccElementFunction,
typename {F_element_func}::PComputeElementFunction,
typename {F_element_func}::OAccElementFunction>;