diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 98ced25f49..2ddd52f996 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -38,12 +38,6 @@ struct FmhaFwdKernel using VLayout = ck_tile::remove_cvref_t; using ElementFunctions = typename FmhaPipeline::Problem::ElementFunctions; - // using QElementFunction = typename ElementFunctions::QElementFunction; - // using KElementFunction = typename ElementFunctions::KElementFunction; - // using VElementFunction = typename ElementFunctions::VElementFunction; - // using BiasElementFunction = typename ElementFunctions::BiasElementFunction; - // using LSEElementFunction = typename ElementFunctions::LSEElementFunction; - // using SAccElementFunction = typename ElementFunctions::SAccElementFunction; using PComputeElementFunction = typename ElementFunctions::PComputeElementFunction; using OAccElementFunction = typename ElementFunctions::OAccElementFunction; @@ -133,12 +127,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - // QElementFunction q_element_func; - // KElementFunction k_element_func; - // VElementFunction v_element_func; - // BiasElementFunction bias_element_func; - // LSEElementFunction lse_element_func; - // SAccElementFunction s_acc_element_func; PComputeElementFunction p_compute_element_func; OAccElementFunction o_acc_element_func; }; @@ -232,12 +220,6 @@ struct FmhaFwdKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - // QElementFunction q_element_func, - // KElementFunction k_element_func, - // VElementFunction v_element_func, - // BiasElementFunction bias_element_func, - // LSEElementFunction lse_element_func, - // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, OAccElementFunction o_acc_element_func) { @@ -263,12 +245,6 @@ struct FmhaFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - // q_element_func, - // k_element_func, - // v_element_func, - // bias_element_func, - // lse_element_func, - // s_acc_element_func, p_compute_element_func, o_acc_element_func}, // args for common karg {}, // placeholder for bias @@ -331,12 +307,6 @@ struct FmhaFwdKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - // QElementFunction q_element_func, - // KElementFunction k_element_func, - // VElementFunction v_element_func, - // BiasElementFunction bias_element_func, - // LSEElementFunction lse_element_func, - // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, OAccElementFunction o_acc_element_func) { @@ -362,12 +332,6 @@ struct FmhaFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - // q_element_func, - // k_element_func, - // v_element_func, - // bias_element_func, - // lse_element_func, - // s_acc_element_func, p_compute_element_func, o_acc_element_func}, // args for common karg {}, // placeholder for bias diff --git a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp index b688c86862..5ce86951d8 100644 --- a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp +++ b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp @@ -7,31 +7,13 @@ namespace ck_tile { -template +/// TODO: support specifying more elementwise functions for input/output tensors +template struct FmhaElementFunctions { - // using QElementFunction = remove_cvref_t; - // using KElementFunction = remove_cvref_t; - // using VElementFunction = remove_cvref_t; - // using BiasElementFunction = remove_cvref_t; - // using LSEElementFunction = remove_cvref_t; - // using SAccElementFunction = remove_cvref_t; using PComputeElementFunction = remove_cvref_t; using OAccElementFunction = remove_cvref_t; - // QElementFunction q_element_func; - // KElementFunction k_element_func; - // VElementFunction v_element_func; - // BiasElementFunction bias_element_func; - // LSEElementFunction lse_element_func; - // SAccElementFunction s_acc_element_func; PComputeElementFunction p_compute_element_func; OAccElementFunction o_acc_element_func; };