From f7d81364f39fc3a9241846130518380e91ea48d2 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 8 Apr 2024 09:42:26 +0000 Subject: [PATCH] To prevent compiler issue, remove the elementwise function we have not used. --- example/ck_tile/01_fmha/fmha_fwd.cpp | 6 - example/ck_tile/01_fmha/fmha_fwd.hpp | 67 ++++++----- example/ck_tile/01_fmha/generate.py | 12 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 105 +++++++++--------- .../fmha/pipeline/fmha_element_functions.hpp | 28 ++--- 5 files changed, 104 insertions(+), 114 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 347e74fbd3..265dc5bb51 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -385,12 +385,6 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.x, ck_tile::identity{}, ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, descale_q * descale_k, descale_v}; }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 7306ae1b17..65be4538e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -60,24 +60,24 @@ struct FmhaFwdTypeConfig struct FmhaDefaultElementFunctions { - using QElementFunction = ck_tile::identity; - using KElementFunction = ck_tile::identity; - using VElementFunction = ck_tile::identity; - using BiasElementFunction = ck_tile::identity; - using LSEElementFunction = ck_tile::identity; - using SAccElementFunction = ck_tile::identity; + // using QElementFunction = ck_tile::identity; + // using KElementFunction = ck_tile::identity; + // using VElementFunction = ck_tile::identity; + // using BiasElementFunction = ck_tile::identity; + // using LSEElementFunction = ck_tile::identity; + // using SAccElementFunction = ck_tile::identity; using PComputeElementFunction = ck_tile::identity; using OAccElementFunction = ck_tile::identity; }; struct FmhaF8StaticQuantizationElementFunctions { - using QElementFunction = ck_tile::identity; - using KElementFunction = ck_tile::identity; - using VElementFunction = ck_tile::identity; - using BiasElementFunction = ck_tile::identity; - using LSEElementFunction = ck_tile::identity; - using SAccElementFunction = ck_tile::identity; + // using QElementFunction = ck_tile::identity; + // using KElementFunction = ck_tile::identity; + // using VElementFunction = ck_tile::identity; + // using BiasElementFunction = ck_tile::identity; + // using LSEElementFunction = ck_tile::identity; + // using SAccElementFunction = ck_tile::identity; using PComputeElementFunction = ck_tile::scale; using OAccElementFunction = ck_tile::composer; }; @@ -316,12 +316,12 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_o; ck_tile::index_t mask_y; ck_tile::index_t mask_x; - typename ElementFunctions::QElementFunction q_element_func; - typename ElementFunctions::KElementFunction k_element_func; - typename ElementFunctions::VElementFunction v_element_func; - typename ElementFunctions::BiasElementFunction bias_element_func; - typename ElementFunctions::LSEElementFunction lse_element_func; - typename ElementFunctions::SAccElementFunction s_acc_element_func; + // typename ElementFunctions::QElementFunction q_element_func; + // typename ElementFunctions::KElementFunction k_element_func; + // typename ElementFunctions::VElementFunction v_element_func; + // typename ElementFunctions::BiasElementFunction bias_element_func; + // typename ElementFunctions::LSEElementFunction lse_element_func; + // typename ElementFunctions::SAccElementFunction s_acc_element_func; typename ElementFunctions::PComputeElementFunction p_compute_element_func; typename ElementFunctions::OAccElementFunction o_acc_element_func; float descale_qk; @@ -362,12 +362,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) args.nhead_stride_o, args.mask_y, args.mask_x, - args.q_element_func, - args.k_element_func, - args.v_element_func, - args.bias_element_func, - args.lse_element_func, - args.s_acc_element_func, + // args.q_element_func, + // args.k_element_func, + // args.v_element_func, + // args.bias_element_func, + // args.lse_element_func, + // args.s_acc_element_func, args.p_compute_element_func, args.o_acc_element_func, args.descale_qk, @@ -406,12 +406,12 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) args.batch_stride_o, args.mask_y, args.mask_x, - args.q_element_func, - args.k_element_func, - args.v_element_func, - args.bias_element_func, - args.lse_element_func, - args.s_acc_element_func, + // args.q_element_func, + // args.k_element_func, + // args.v_element_func, + // args.bias_element_func, + // args.lse_element_func, + // args.s_acc_element_func, args.p_compute_element_func, args.o_acc_element_func, args.descale_qk, @@ -479,8 +479,5 @@ struct fmha_fwd_traits // TODO: padding check is inside this api }; -template -float fmha_fwd(fmha_fwd_traits, - FmhaFwdArgs_, - const ck_tile::stream_config&); - +template +float fmha_fwd(fmha_fwd_traits, FmhaFwdArgs_, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 1b1e7ba775..74086b61ab 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -90,12 +90,12 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, using fmha_mask_{F_idx} = {F_mask}; using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions< - typename {F_element_func}::QElementFunction, - typename {F_element_func}::KElementFunction, - typename {F_element_func}::VElementFunction, - typename {F_element_func}::BiasElementFunction, - typename {F_element_func}::LSEElementFunction, - typename {F_element_func}::SAccElementFunction, + // typename {F_element_func}::QElementFunction, + // typename {F_element_func}::KElementFunction, + // typename {F_element_func}::VElementFunction, + // typename {F_element_func}::BiasElementFunction, + // typename {F_element_func}::LSEElementFunction, + // typename {F_element_func}::SAccElementFunction, typename {F_element_func}::PComputeElementFunction, typename {F_element_func}::OAccElementFunction>; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index cfdd21f80e..bc35525a71 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -35,24 +35,23 @@ struct FmhaFwdKernel using ODataType = ck_tile::remove_cvref_t; static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8; - using VLayout = ck_tile::remove_cvref_t; + using VLayout = ck_tile::remove_cvref_t; + using ElementFunctions = typename FmhaPipeline::Problem::ElementFunctions; - using QElementFunction = - ck_tile::remove_cvref_t; - using KElementFunction = - ck_tile::remove_cvref_t; - using VElementFunction = - ck_tile::remove_cvref_t; - using BiasElementFunction = ck_tile::remove_cvref_t< - typename FmhaPipeline::Problem::ElementFunctions::BiasElementFunction>; - using LSEElementFunction = ck_tile::remove_cvref_t< - typename FmhaPipeline::Problem::ElementFunctions::LSEElementFunction>; - using SAccElementFunction = ck_tile::remove_cvref_t< - typename FmhaPipeline::Problem::ElementFunctions::SAccElementFunction>; - using PComputeElementFunction = ck_tile::remove_cvref_t< - typename FmhaPipeline::Problem::ElementFunctions::PComputeElementFunction>; - using OAccElementFunction = ck_tile::remove_cvref_t< - typename FmhaPipeline::Problem::ElementFunctions::OAccElementFunction>; + // using QElementFunction = ck_tile::remove_cvref_t; using KElementFunction = + // ck_tile::remove_cvref_t; using VElementFunction + // = ck_tile::remove_cvref_t; using + // BiasElementFunction = + // ck_tile::remove_cvref_t; + // using LSEElementFunction = + // ck_tile::remove_cvref_t; + // using SAccElementFunction = + // ck_tile::remove_cvref_t; + using PComputeElementFunction = + ck_tile::remove_cvref_t; + using OAccElementFunction = + ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; @@ -140,12 +139,12 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - QElementFunction q_element_func; - KElementFunction k_element_func; - VElementFunction v_element_func; - BiasElementFunction bias_element_func; - LSEElementFunction lse_element_func; - SAccElementFunction s_acc_element_func; + // QElementFunction q_element_func; + // KElementFunction k_element_func; + // VElementFunction v_element_func; + // BiasElementFunction bias_element_func; + // LSEElementFunction lse_element_func; + // SAccElementFunction s_acc_element_func; PComputeElementFunction p_compute_element_func; OAccElementFunction o_acc_element_func; }; @@ -245,12 +244,12 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t mask_y, ck_tile::index_t mask_x, - QElementFunction q_element_func, - KElementFunction k_element_func, - VElementFunction v_element_func, - BiasElementFunction bias_element_func, - LSEElementFunction lse_element_func, - SAccElementFunction s_acc_element_func, + // QElementFunction q_element_func, + // KElementFunction k_element_func, + // VElementFunction v_element_func, + // BiasElementFunction bias_element_func, + // LSEElementFunction lse_element_func, + // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, OAccElementFunction o_acc_element_func, float descale_qk, @@ -278,12 +277,12 @@ struct FmhaFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - q_element_func, - k_element_func, - v_element_func, - bias_element_func, - lse_element_func, - s_acc_element_func, + // q_element_func, + // k_element_func, + // v_element_func, + // bias_element_func, + // lse_element_func, + // s_acc_element_func, p_compute_element_func, o_acc_element_func}, // args for common karg {}, // placeholder for bias @@ -350,12 +349,12 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t mask_y, ck_tile::index_t mask_x, - QElementFunction q_element_func, - KElementFunction k_element_func, - VElementFunction v_element_func, - BiasElementFunction bias_element_func, - LSEElementFunction lse_element_func, - SAccElementFunction s_acc_element_func, + // QElementFunction q_element_func, + // KElementFunction k_element_func, + // VElementFunction v_element_func, + // BiasElementFunction bias_element_func, + // LSEElementFunction lse_element_func, + // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, OAccElementFunction o_acc_element_func, float descale_qk, @@ -383,12 +382,12 @@ struct FmhaFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - q_element_func, - k_element_func, - v_element_func, - bias_element_func, - lse_element_func, - s_acc_element_func, + // q_element_func, + // k_element_func, + // v_element_func, + // bias_element_func, + // lse_element_func, + // s_acc_element_func, p_compute_element_func, o_acc_element_func}, // args for common karg {}, // placeholder for bias @@ -719,16 +718,16 @@ struct FmhaFwdKernel else { return FmhaPipeline{}(q_dram_window, - kargs.q_element_func, + identity{}, k_dram_window, - kargs.k_element_func, + identity{}, v_dram_window, - kargs.v_element_func, + identity{}, bias_dram_window, - kargs.bias_element_func, + identity{}, lse_dram_window, - kargs.lse_element_func, - kargs.s_acc_element_func, + identity{}, + identity{}, kargs.p_compute_element_func, kargs.o_acc_element_func, mask, diff --git a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp index 397b15c6d2..b688c86862 100644 --- a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp +++ b/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp @@ -7,31 +7,31 @@ namespace ck_tile { -template struct FmhaElementFunctions { - using QElementFunction = remove_cvref_t; - using KElementFunction = remove_cvref_t; - using VElementFunction = remove_cvref_t; - using BiasElementFunction = remove_cvref_t; - using LSEElementFunction = remove_cvref_t; - using SAccElementFunction = remove_cvref_t; + // using QElementFunction = remove_cvref_t; + // using KElementFunction = remove_cvref_t; + // using VElementFunction = remove_cvref_t; + // using BiasElementFunction = remove_cvref_t; + // using LSEElementFunction = remove_cvref_t; + // using SAccElementFunction = remove_cvref_t; using PComputeElementFunction = remove_cvref_t; using OAccElementFunction = remove_cvref_t; - QElementFunction q_element_func; - KElementFunction k_element_func; - VElementFunction v_element_func; - BiasElementFunction bias_element_func; - LSEElementFunction lse_element_func; - SAccElementFunction s_acc_element_func; + // QElementFunction q_element_func; + // KElementFunction k_element_func; + // VElementFunction v_element_func; + // BiasElementFunction bias_element_func; + // LSEElementFunction lse_element_func; + // SAccElementFunction s_acc_element_func; PComputeElementFunction p_compute_element_func; OAccElementFunction o_acc_element_func; };