From a9adfbe54a0b4f2aa7ef23f3d4798b0662f6330e Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 9 Apr 2024 06:45:03 +0000 Subject: [PATCH] Small refinements in C++ source files --- example/ck_tile/01_fmha/generate.py | 17 ++++++----------- include/ck_tile/core.hpp | 2 +- include/ck_tile/ops/fmha.hpp | 2 +- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 3 +-- ...ions.hpp => tile_fmha_element_functions.hpp} | 2 +- 5 files changed, 10 insertions(+), 16 deletions(-) rename include/ck_tile/ops/fmha/pipeline/{fmha_element_functions.hpp => tile_fmha_element_functions.hpp} (95%) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 739d22904b..8216ce044e 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -55,15 +55,16 @@ PIPELINE_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", } -ELEMENT_FUNC_MAP = { - "no" : "FmhaDefaultElementFunctions", - "f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions", -} PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", } +ELEMENT_FUNC_MAP = { + "no" : "FmhaDefaultElementFunctions", + "f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions", +} + BOOL_MAP = { "t" : "true", "f" : "false" @@ -101,13 +102,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; -using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions< - // typename {F_element_func}::QElementFunction, - // typename {F_element_func}::KElementFunction, - // typename {F_element_func}::VElementFunction, - // typename {F_element_func}::BiasElementFunction, - // typename {F_element_func}::LSEElementFunction, - // typename {F_element_func}::SAccElementFunction, +using fmha_element_function_{F_idx} = ck_tile::TileFmhaElementFunctions< typename {F_element_func}::PComputeElementFunction, typename {F_element_func}::OAccElementFunction>; diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 50d2ac8d9a..bb19c9154b 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -49,10 +49,10 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" -#include "ck_tile/core/utility/unary_element_function.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/unary_element_function.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 211852f8ff..fa3ede73d2 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -16,7 +16,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_element_functions.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 78527f4633..1ada56ddea 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -9,10 +9,9 @@ namespace ck_tile { -// deprecated pipeline // This pipeline is qkv all located in LDS template -struct BlockFmhaPipelineQRKSVSFp8 +struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 { using Problem = remove_cvref_t; using Policy = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_element_functions.hpp similarity index 95% rename from include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp rename to include/ck_tile/ops/fmha/pipeline/tile_fmha_element_functions.hpp index 5ce86951d8..2985f82059 100644 --- a/include/ck_tile/ops/fmha/pipeline/fmha_element_functions.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_element_functions.hpp @@ -9,7 +9,7 @@ namespace ck_tile { /// TODO: support specifying more elementwise functions for input/output tensors template -struct FmhaElementFunctions +struct TileFmhaElementFunctions { using PComputeElementFunction = remove_cvref_t; using OAccElementFunction = remove_cvref_t;