Remove not-in-use elementwise function kargs

This commit is contained in:
Po Yen Chen
2024-04-09 06:03:35 +00:00
parent b64d3f6eec
commit 20fcd69687
2 changed files with 2 additions and 56 deletions

View File

@@ -38,12 +38,6 @@ struct FmhaFwdKernel
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
using ElementFunctions = typename FmhaPipeline::Problem::ElementFunctions;
// using QElementFunction = typename ElementFunctions::QElementFunction;
// using KElementFunction = typename ElementFunctions::KElementFunction;
// using VElementFunction = typename ElementFunctions::VElementFunction;
// using BiasElementFunction = typename ElementFunctions::BiasElementFunction;
// using LSEElementFunction = typename ElementFunctions::LSEElementFunction;
// using SAccElementFunction = typename ElementFunctions::SAccElementFunction;
using PComputeElementFunction = typename ElementFunctions::PComputeElementFunction;
using OAccElementFunction = typename ElementFunctions::OAccElementFunction;
@@ -133,12 +127,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
// QElementFunction q_element_func;
// KElementFunction k_element_func;
// VElementFunction v_element_func;
// BiasElementFunction bias_element_func;
// LSEElementFunction lse_element_func;
// SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};
@@ -232,12 +220,6 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
// BiasElementFunction bias_element_func,
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func)
{
@@ -263,12 +245,6 @@ struct FmhaFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
// q_element_func,
// k_element_func,
// v_element_func,
// bias_element_func,
// lse_element_func,
// s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias
@@ -331,12 +307,6 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
// QElementFunction q_element_func,
// KElementFunction k_element_func,
// VElementFunction v_element_func,
// BiasElementFunction bias_element_func,
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func)
{
@@ -362,12 +332,6 @@ struct FmhaFwdKernel
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
// q_element_func,
// k_element_func,
// v_element_func,
// bias_element_func,
// lse_element_func,
// s_acc_element_func,
p_compute_element_func,
o_acc_element_func}, // args for common karg
{}, // placeholder for bias

View File

@@ -7,31 +7,13 @@
namespace ck_tile {
template </*typename QElementFunction_,
typename KElementFunction_,
typename VElementFunction_,
typename BiasElementFunction_,
typename LSEElementFunction_,
typename SAccElementFunction_,*/
typename PComputeElementFunction_,
typename OAccElementFunction_>
/// TODO: support specifying more elementwise functions for input/output tensors
template <typename PComputeElementFunction_, typename OAccElementFunction_>
struct FmhaElementFunctions
{
// using QElementFunction = remove_cvref_t<QElementFunction_>;
// using KElementFunction = remove_cvref_t<KElementFunction_>;
// using VElementFunction = remove_cvref_t<VElementFunction_>;
// using BiasElementFunction = remove_cvref_t<BiasElementFunction_>;
// using LSEElementFunction = remove_cvref_t<LSEElementFunction_>;
// using SAccElementFunction = remove_cvref_t<SAccElementFunction_>;
using PComputeElementFunction = remove_cvref_t<PComputeElementFunction_>;
using OAccElementFunction = remove_cvref_t<OAccElementFunction_>;
// QElementFunction q_element_func;
// KElementFunction k_element_func;
// VElementFunction v_element_func;
// BiasElementFunction bias_element_func;
// LSEElementFunction lse_element_func;
// SAccElementFunction s_acc_element_func;
PComputeElementFunction p_compute_element_func;
OAccElementFunction o_acc_element_func;
};