mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
To prevent compiler issue, remove the elementwise function we have not used.
This commit is contained in:
@@ -385,12 +385,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.x,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
descale_q * descale_k,
|
||||
descale_v};
|
||||
}();
|
||||
|
||||
@@ -60,24 +60,24 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
|
||||
struct FmhaDefaultElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
// using QElementFunction = ck_tile::identity;
|
||||
// using KElementFunction = ck_tile::identity;
|
||||
// using VElementFunction = ck_tile::identity;
|
||||
// using BiasElementFunction = ck_tile::identity;
|
||||
// using LSEElementFunction = ck_tile::identity;
|
||||
// using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::identity;
|
||||
using OAccElementFunction = ck_tile::identity;
|
||||
};
|
||||
|
||||
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
// using QElementFunction = ck_tile::identity;
|
||||
// using KElementFunction = ck_tile::identity;
|
||||
// using VElementFunction = ck_tile::identity;
|
||||
// using BiasElementFunction = ck_tile::identity;
|
||||
// using LSEElementFunction = ck_tile::identity;
|
||||
// using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::scale;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
|
||||
};
|
||||
@@ -316,12 +316,12 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t mask_y;
|
||||
ck_tile::index_t mask_x;
|
||||
typename ElementFunctions::QElementFunction q_element_func;
|
||||
typename ElementFunctions::KElementFunction k_element_func;
|
||||
typename ElementFunctions::VElementFunction v_element_func;
|
||||
typename ElementFunctions::BiasElementFunction bias_element_func;
|
||||
typename ElementFunctions::LSEElementFunction lse_element_func;
|
||||
typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
// typename ElementFunctions::QElementFunction q_element_func;
|
||||
// typename ElementFunctions::KElementFunction k_element_func;
|
||||
// typename ElementFunctions::VElementFunction v_element_func;
|
||||
// typename ElementFunctions::BiasElementFunction bias_element_func;
|
||||
// typename ElementFunctions::LSEElementFunction lse_element_func;
|
||||
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
|
||||
typename ElementFunctions::OAccElementFunction o_acc_element_func;
|
||||
float descale_qk;
|
||||
@@ -362,12 +362,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
args.nhead_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
// args.q_element_func,
|
||||
// args.k_element_func,
|
||||
// args.v_element_func,
|
||||
// args.bias_element_func,
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
@@ -406,12 +406,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
args.batch_stride_o,
|
||||
args.mask_y,
|
||||
args.mask_x,
|
||||
args.q_element_func,
|
||||
args.k_element_func,
|
||||
args.v_element_func,
|
||||
args.bias_element_func,
|
||||
args.lse_element_func,
|
||||
args.s_acc_element_func,
|
||||
// args.q_element_func,
|
||||
// args.k_element_func,
|
||||
// args.v_element_func,
|
||||
// args.bias_element_func,
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
@@ -479,8 +479,5 @@ struct fmha_fwd_traits
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
template<typename FmhaFwdArgs_>
|
||||
float fmha_fwd(fmha_fwd_traits,
|
||||
FmhaFwdArgs_,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
template <typename FmhaFwdArgs_>
|
||||
float fmha_fwd(fmha_fwd_traits, FmhaFwdArgs_, const ck_tile::stream_config&);
|
||||
|
||||
@@ -90,12 +90,12 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
|
||||
typename {F_element_func}::QElementFunction,
|
||||
typename {F_element_func}::KElementFunction,
|
||||
typename {F_element_func}::VElementFunction,
|
||||
typename {F_element_func}::BiasElementFunction,
|
||||
typename {F_element_func}::LSEElementFunction,
|
||||
typename {F_element_func}::SAccElementFunction,
|
||||
// typename {F_element_func}::QElementFunction,
|
||||
// typename {F_element_func}::KElementFunction,
|
||||
// typename {F_element_func}::VElementFunction,
|
||||
// typename {F_element_func}::BiasElementFunction,
|
||||
// typename {F_element_func}::LSEElementFunction,
|
||||
// typename {F_element_func}::SAccElementFunction,
|
||||
typename {F_element_func}::PComputeElementFunction,
|
||||
typename {F_element_func}::OAccElementFunction>;
|
||||
|
||||
|
||||
@@ -35,24 +35,23 @@ struct FmhaFwdKernel
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
using ElementFunctions = typename FmhaPipeline::Problem::ElementFunctions;
|
||||
|
||||
using QElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::QElementFunction>;
|
||||
using KElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::KElementFunction>;
|
||||
using VElementFunction =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::Problem::ElementFunctions::VElementFunction>;
|
||||
using BiasElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>;
|
||||
using LSEElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>;
|
||||
using SAccElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>;
|
||||
using PComputeElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>;
|
||||
using OAccElementFunction = ck_tile::remove_cvref_t<
|
||||
typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>;
|
||||
// using QElementFunction = ck_tile::remove_cvref_t<typename
|
||||
// ElementFunctions::QElementFunction>; using KElementFunction =
|
||||
// ck_tile::remove_cvref_t<typename ElementFunctions::KElementFunction>; using VElementFunction
|
||||
// = ck_tile::remove_cvref_t<typename ElementFunctions::VElementFunction>; using
|
||||
// BiasElementFunction =
|
||||
// ck_tile::remove_cvref_t<typename ElementFunctions::BiasElementFunction>;
|
||||
// using LSEElementFunction =
|
||||
// ck_tile::remove_cvref_t<typename ElementFunctions::LSEElementFunction>;
|
||||
// using SAccElementFunction =
|
||||
// ck_tile::remove_cvref_t<typename ElementFunctions::SAccElementFunction>;
|
||||
using PComputeElementFunction =
|
||||
ck_tile::remove_cvref_t<typename ElementFunctions::PComputeElementFunction>;
|
||||
using OAccElementFunction =
|
||||
ck_tile::remove_cvref_t<typename ElementFunctions::OAccElementFunction>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
@@ -140,12 +139,12 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
QElementFunction q_element_func;
|
||||
KElementFunction k_element_func;
|
||||
VElementFunction v_element_func;
|
||||
BiasElementFunction bias_element_func;
|
||||
LSEElementFunction lse_element_func;
|
||||
SAccElementFunction s_acc_element_func;
|
||||
// QElementFunction q_element_func;
|
||||
// KElementFunction k_element_func;
|
||||
// VElementFunction v_element_func;
|
||||
// BiasElementFunction bias_element_func;
|
||||
// LSEElementFunction lse_element_func;
|
||||
// SAccElementFunction s_acc_element_func;
|
||||
PComputeElementFunction p_compute_element_func;
|
||||
OAccElementFunction o_acc_element_func;
|
||||
};
|
||||
@@ -245,12 +244,12 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
QElementFunction q_element_func,
|
||||
KElementFunction k_element_func,
|
||||
VElementFunction v_element_func,
|
||||
BiasElementFunction bias_element_func,
|
||||
LSEElementFunction lse_element_func,
|
||||
SAccElementFunction s_acc_element_func,
|
||||
// QElementFunction q_element_func,
|
||||
// KElementFunction k_element_func,
|
||||
// VElementFunction v_element_func,
|
||||
// BiasElementFunction bias_element_func,
|
||||
// LSEElementFunction lse_element_func,
|
||||
// SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
@@ -278,12 +277,12 @@ struct FmhaFwdKernel
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
q_element_func,
|
||||
k_element_func,
|
||||
v_element_func,
|
||||
bias_element_func,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
// q_element_func,
|
||||
// k_element_func,
|
||||
// v_element_func,
|
||||
// bias_element_func,
|
||||
// lse_element_func,
|
||||
// s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -350,12 +349,12 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
QElementFunction q_element_func,
|
||||
KElementFunction k_element_func,
|
||||
VElementFunction v_element_func,
|
||||
BiasElementFunction bias_element_func,
|
||||
LSEElementFunction lse_element_func,
|
||||
SAccElementFunction s_acc_element_func,
|
||||
// QElementFunction q_element_func,
|
||||
// KElementFunction k_element_func,
|
||||
// VElementFunction v_element_func,
|
||||
// BiasElementFunction bias_element_func,
|
||||
// LSEElementFunction lse_element_func,
|
||||
// SAccElementFunction s_acc_element_func,
|
||||
PComputeElementFunction p_compute_element_func,
|
||||
OAccElementFunction o_acc_element_func,
|
||||
float descale_qk,
|
||||
@@ -383,12 +382,12 @@ struct FmhaFwdKernel
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
q_element_func,
|
||||
k_element_func,
|
||||
v_element_func,
|
||||
bias_element_func,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
// q_element_func,
|
||||
// k_element_func,
|
||||
// v_element_func,
|
||||
// bias_element_func,
|
||||
// lse_element_func,
|
||||
// s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -719,16 +718,16 @@ struct FmhaFwdKernel
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
kargs.q_element_func,
|
||||
identity{},
|
||||
k_dram_window,
|
||||
kargs.k_element_func,
|
||||
identity{},
|
||||
v_dram_window,
|
||||
kargs.v_element_func,
|
||||
identity{},
|
||||
bias_dram_window,
|
||||
kargs.bias_element_func,
|
||||
identity{},
|
||||
lse_dram_window,
|
||||
kargs.lse_element_func,
|
||||
kargs.s_acc_element_func,
|
||||
identity{},
|
||||
identity{},
|
||||
kargs.p_compute_element_func,
|
||||
kargs.o_acc_element_func,
|
||||
mask,
|
||||
|
||||
@@ -7,31 +7,31 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QElementFunction_,
|
||||
template </*typename QElementFunction_,
|
||||
typename KElementFunction_,
|
||||
typename VElementFunction_,
|
||||
typename BiasElementFunction_,
|
||||
typename LSEElementFunction_,
|
||||
typename SAccElementFunction_,
|
||||
typename SAccElementFunction_,*/
|
||||
typename PComputeElementFunction_,
|
||||
typename OAccElementFunction_>
|
||||
struct FmhaElementFunctions
|
||||
{
|
||||
using QElementFunction = remove_cvref_t<QElementFunction_>;
|
||||
using KElementFunction = remove_cvref_t<KElementFunction_>;
|
||||
using VElementFunction = remove_cvref_t<VElementFunction_>;
|
||||
using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
|
||||
using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
|
||||
using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
|
||||
// using QElementFunction = remove_cvref_t<QElementFunction_>;
|
||||
// using KElementFunction = remove_cvref_t<KElementFunction_>;
|
||||
// using VElementFunction = remove_cvref_t<VElementFunction_>;
|
||||
// using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
|
||||
// using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
|
||||
// using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
|
||||
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
|
||||
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
|
||||
|
||||
QElementFunction q_element_func;
|
||||
KElementFunction k_element_func;
|
||||
VElementFunction v_element_func;
|
||||
BiasElementFunction bias_element_func;
|
||||
LSEElementFunction lse_element_func;
|
||||
SAccElementFunction s_acc_element_func;
|
||||
// QElementFunction q_element_func;
|
||||
// KElementFunction k_element_func;
|
||||
// VElementFunction v_element_func;
|
||||
// BiasElementFunction bias_element_func;
|
||||
// LSEElementFunction lse_element_func;
|
||||
// SAccElementFunction s_acc_element_func;
|
||||
PComputeElementFunction p_compute_element_func;
|
||||
OAccElementFunction o_acc_element_func;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user