To prevent compiler issue, remove the elementwise function we have not used.

This commit is contained in:
rocking
2024-04-08 09:42:26 +00:00
parent 68153dea0b
commit f7d81364f3
5 changed files with 104 additions and 114 deletions

View File

@@ -385,12 +385,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.x,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
descale_q * descale_k,
descale_v};
}();

View File

@@ -60,24 +60,24 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
struct FmhaDefaultElementFunctions
{
using QElementFunction = ck_tile::identity;
using KElementFunction = ck_tile::identity;
using VElementFunction = ck_tile::identity;
using BiasElementFunction = ck_tile::identity;
using LSEElementFunction = ck_tile::identity;
using SAccElementFunction = ck_tile::identity;
// using QElementFunction = ck_tile::identity;
// using KElementFunction = ck_tile::identity;
// using VElementFunction = ck_tile::identity;
// using BiasElementFunction = ck_tile::identity;
// using LSEElementFunction = ck_tile::identity;
// using SAccElementFunction = ck_tile::identity;
using PComputeElementFunction = ck_tile::identity;
using OAccElementFunction = ck_tile::identity;
};
struct FmhaF8StaticQuantizationElementFunctions
{
using QElementFunction = ck_tile::identity;
using KElementFunction = ck_tile::identity;
using VElementFunction = ck_tile::identity;
using BiasElementFunction = ck_tile::identity;
using LSEElementFunction = ck_tile::identity;
using SAccElementFunction = ck_tile::identity;
// using QElementFunction = ck_tile::identity;
// using KElementFunction = ck_tile::identity;
// using VElementFunction = ck_tile::identity;
// using BiasElementFunction = ck_tile::identity;
// using LSEElementFunction = ck_tile::identity;
// using SAccElementFunction = ck_tile::identity;
using PComputeElementFunction = ck_tile::scale;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
};
@@ -316,12 +316,12 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
typename ElementFunctions::QElementFunction q_element_func;
typename ElementFunctions::KElementFunction k_element_func;
typename ElementFunctions::VElementFunction v_element_func;
typename ElementFunctions::BiasElementFunction bias_element_func;
typename ElementFunctions::LSEElementFunction lse_element_func;
typename ElementFunctions::SAccElementFunction s_acc_element_func;
// typename ElementFunctions::QElementFunction q_element_func;
// typename ElementFunctions::KElementFunction k_element_func;
// typename ElementFunctions::VElementFunction v_element_func;
// typename ElementFunctions::BiasElementFunction bias_element_func;
// typename ElementFunctions::LSEElementFunction lse_element_func;
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
typename ElementFunctions::OAccElementFunction o_acc_element_func;
float descale_qk;
@@ -362,12 +362,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
// args.bias_element_func,
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
@@ -406,12 +406,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.q_element_func,
args.k_element_func,
args.v_element_func,
args.bias_element_func,
args.lse_element_func,
args.s_acc_element_func,
// args.q_element_func,
// args.k_element_func,
// args.v_element_func,
// args.bias_element_func,
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
@@ -479,8 +479,5 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
template<typename FmhaFwdArgs_>
float fmha_fwd(fmha_fwd_traits,
FmhaFwdArgs_,
const ck_tile::stream_config&);
template <typename FmhaFwdArgs_>
float fmha_fwd(fmha_fwd_traits, FmhaFwdArgs_, const ck_tile::stream_config&);

View File

@@ -90,12 +90,12 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_mask_{F_idx} = {F_mask};
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
typename {F_element_func}::QElementFunction,
typename {F_element_func}::KElementFunction,
typename {F_element_func}::VElementFunction,
typename {F_element_func}::BiasElementFunction,
typename {F_element_func}::LSEElementFunction,
typename {F_element_func}::SAccElementFunction,
// typename {F_element_func}::QElementFunction,
// typename {F_element_func}::KElementFunction,
// typename {F_element_func}::VElementFunction,
// typename {F_element_func}::BiasElementFunction,
// typename {F_element_func}::LSEElementFunction,
// typename {F_element_func}::SAccElementFunction,
typename {F_element_func}::PComputeElementFunction,
typename {F_element_func}::OAccElementFunction>;

View File

@@ -35,24 +35,23 @@ struct FmhaFwdKernel
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
using ElementFunctions = typename FmhaPipeline::Problem::ElementFunctions;
using QElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::QElementFunction>;
using KElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::KElementFunction>;
using VElementFunction =
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::VElementFunction>;
using BiasElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>;
using LSEElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>;
using SAccElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>;
using PComputeElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>;
using OAccElementFunction = ck_tile::remove_cvref_t<
typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>;
// using QElementFunction = ck_tile::remove_cvref_t<typename
// ElementFunctions::QElementFunction>; using KElementFunction =
// ck_tile::remove_cvref_t<typename ElementFunctions::KElementFunction>; using VElementFunction
// = ck_tile::remove_cvref_t<typename ElementFunctions::VElementFunction>; using
// BiasElementFunction =
// ck_tile::remove_cvref_t<typename ElementFunctions::BiasElementFunction>;
// using LSEElementFunction =
// ck_tile::remove_cvref_t<typename ElementFunctions::LSEElementFunction>;
// using SAccElementFunction =
// ck_tile::remove_cvref_t<typename ElementFunctions::SAccElementFunction>;
using PComputeElementFunction =
ck_tile::remove_cvref_t<typename ElementFunctions::PComputeElementFunction>;
using OAccElementFunction =
ck_tile::remove_cvref_t<typename ElementFunctions::OAccElementFunction>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
@@ -140,12 +139,12 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
QElementFunction q_element_func;
KElementFunction k_element_func;
VElementFunction v_element_func;
BiasElementFunction bias_element_func;
LSEElementFunction lse_element_func;
SAccElementFunction s_acc_element_func;
// QElementFunction q_element_func;
// KElementFunction k_element_func;
// VElementFunction v_element_func;
// BiasElementFunction bias_element_func;
// LSEElementFunction lse_element_func;
// SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};
@@ -245,12 +244,12 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
QElementFunction q_element_func,
KElementFunction k_element_func,
VElementFunction v_element_func,
BiasElementFunction bias_element_func,
LSEElementFunction lse_element_func,
SAccElementFunction s_acc_element_func,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
// BiasElementFunction bias_element_func,
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
@@ -278,12 +277,12 @@ struct FmhaFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
q_element_func,
k_element_func,
v_element_func,
bias_element_func,
lse_element_func,
s_acc_element_func,
// q_element_func,
// k_element_func,
// v_element_func,
// bias_element_func,
// lse_element_func,
// s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias
@@ -350,12 +349,12 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
QElementFunction q_element_func,
KElementFunction k_element_func,
VElementFunction v_element_func,
BiasElementFunction bias_element_func,
LSEElementFunction lse_element_func,
SAccElementFunction s_acc_element_func,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
// BiasElementFunction bias_element_func,
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
@@ -383,12 +382,12 @@ struct FmhaFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
q_element_func,
k_element_func,
v_element_func,
bias_element_func,
lse_element_func,
s_acc_element_func,
// q_element_func,
// k_element_func,
// v_element_func,
// bias_element_func,
// lse_element_func,
// s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias
@@ -719,16 +718,16 @@ struct FmhaFwdKernel
else
{
return FmhaPipeline{}(q_dram_window,
kargs.q_element_func,
identity{},
k_dram_window,
kargs.k_element_func,
identity{},
v_dram_window,
kargs.v_element_func,
identity{},
bias_dram_window,
kargs.bias_element_func,
identity{},
lse_dram_window,
kargs.lse_element_func,
kargs.s_acc_element_func,
identity{},
identity{},
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,

View File

@@ -7,31 +7,31 @@
namespace ck_tile {
template <typename QElementFunction_,
template </*typename QElementFunction_,
typename KElementFunction_,
typename VElementFunction_,
typename BiasElementFunction_,
typename LSEElementFunction_,
typename SAccElementFunction_,
typename SAccElementFunction_,*/
typename PComputeElementFunction_,
typename OAccElementFunction_>
struct FmhaElementFunctions
{
using QElementFunction = remove_cvref_t<QElementFunction_>;
using KElementFunction = remove_cvref_t<KElementFunction_>;
using VElementFunction = remove_cvref_t<VElementFunction_>;
using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
// using QElementFunction = remove_cvref_t<QElementFunction_>;
// using KElementFunction = remove_cvref_t<KElementFunction_>;
// using VElementFunction = remove_cvref_t<VElementFunction_>;
// using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
// using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
// using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
QElementFunction q_element_func;
KElementFunction k_element_func;
VElementFunction v_element_func;
BiasElementFunction bias_element_func;
LSEElementFunction lse_element_func;
SAccElementFunction s_acc_element_func;
// QElementFunction q_element_func;
// KElementFunction k_element_func;
// VElementFunction v_element_func;
// BiasElementFunction bias_element_func;
// LSEElementFunction lse_element_func;
// SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};