From 68153dea0b90a306b4113d04362d661ffe08b1be Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 4 Apr 2024 03:59:38 +0000 Subject: [PATCH] Let generate.py can generate different elementwise function --- example/ck_tile/01_fmha/fmha_fwd.hpp | 27 +++++-- example/ck_tile/01_fmha/generate.py | 113 ++++++++++++++++----------- include/ck_tile/core.hpp | 1 + 3 files changed, 89 insertions(+), 52 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index a15dcb790a..7306ae1b17 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -49,7 +49,7 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t; - using BiasDataType = float; // TODO: fix me + using BiasDataType = float; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using SaccDataType = float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax @@ -70,6 +70,18 @@ struct FmhaDefaultElementFunctions using OAccElementFunction = ck_tile::identity; }; +struct FmhaF8StaticQuantizationElementFunctions +{ + using QElementFunction = ck_tile::identity; + using KElementFunction = ck_tile::identity; + using VElementFunction = ck_tile::identity; + using BiasElementFunction = ck_tile::identity; + using LSEElementFunction = ck_tile::identity; + using SAccElementFunction = ck_tile::identity; + using PComputeElementFunction = ck_tile::scale; + using OAccElementFunction = ck_tile::composer; +}; + template <> struct FmhaFwdTypeConfig { @@ -316,8 +328,8 @@ struct fmha_fwd_args float descale_sv; }; -template -auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +template +auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -450,8 +462,8 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; }; -template -float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +float fmha_fwd_(const ck_tile::stream_config&, FmhaFwdArgs_); // This is the public API, will be generated by script struct fmha_fwd_traits @@ -466,6 +478,9 @@ struct fmha_fwd_traits bool has_lse; // TODO: padding check is inside this api }; + +template float fmha_fwd(fmha_fwd_traits, - fmha_fwd_args, + FmhaFwdArgs_, const ck_tile::stream_config&); + diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 2ead9de3be..1b1e7ba775 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -46,6 +46,11 @@ PIPELINE_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", } +ELEMENT_FUNC_MAP = { + "no" : "FmhaDefaultElementFunctions", + "f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions", +} + BOOL_MAP = { "t" : "true", "f" : "false" @@ -85,14 +90,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, using fmha_mask_{F_idx} = {F_mask}; using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions< - typename FmhaDefaultElementFunctions::QElementFunction, - typename FmhaDefaultElementFunctions::KElementFunction, - typename FmhaDefaultElementFunctions::VElementFunction, - typename FmhaDefaultElementFunctions::BiasElementFunction, - typename FmhaDefaultElementFunctions::LSEElementFunction, - typename FmhaDefaultElementFunctions::SAccElementFunction, - typename FmhaDefaultElementFunctions::PComputeElementFunction, - typename FmhaDefaultElementFunctions::OAccElementFunction>; + typename {F_element_func}::QElementFunction, + typename {F_element_func}::KElementFunction, + typename {F_element_func}::VElementFunction, + typename {F_element_func}::BiasElementFunction, + typename {F_element_func}::LSEElementFunction, + typename {F_element_func}::SAccElementFunction, + typename {F_element_func}::PComputeElementFunction, + typename {F_element_func}::OAccElementFunction>; + +using fmha_fwd_args_{F_idx} = fmha_fwd_args<{F_element_func}>; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -129,7 +136,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F #include template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args_{F_idx} a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) @@ -143,7 +150,9 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a, const ck_tile::stream_config& s){{ +using fmha_fwd_args_ = fmha_fwd_args<{F_element_func}>; +template<> +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args_ a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; @@ -183,7 +192,7 @@ class FmhaFwdApiTrait: bk0 : int # tile size along qk gemm unroll bn1 : int # tile size along v head_dim bk1 : int # tile size along kv gemm unroll - bk0blen : int + bk0blen : int vlayout : str mask : str bias : str # true/false @@ -283,6 +292,7 @@ class FmhaFwdApiPool: def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): + element_func='no' per_hdim_case=str() for j, hdim in enumerate(self.pool[dtype].keys()): traits=self.pool[dtype][hdim] @@ -299,7 +309,8 @@ class FmhaFwdApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP[element_func], + F_dispatch = per_dtypes) @dataclass class FmhaFwdTileSize: @@ -324,44 +335,46 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + F_element_func : str @property def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + \ FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BOOL_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_occupancy = self.F_tile.F_occupancy , - F_mask = MASK_MAP[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_occupancy = self.F_tile.F_occupancy , + F_mask = MASK_MAP[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_element_func = ELEMENT_FUNC_MAP[self.F_element_func]) @property def name(self) -> str: @@ -454,7 +467,15 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + element_func='no' + k = FmhaFwdKernel(direction=direction, + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + F_element_func=element_func) if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2767ee05b2..50d2ac8d9a 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -49,6 +49,7 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/unary_element_function.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/random.hpp"