Let generate.py generate different elementwise functions

This commit is contained in:
rocking
2024-04-04 03:59:38 +00:00
parent d6cb104d0f
commit 68153dea0b
3 changed files with 89 additions and 52 deletions

View File

@@ -49,7 +49,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float; // TODO: fix me
using BiasDataType = float;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
@@ -70,6 +70,18 @@ struct FmhaDefaultElementFunctions
using OAccElementFunction = ck_tile::identity;
};
// Element-wise function bundle for the FP8 static-quantization configuration.
// It exposes the same eight alias names as FmhaDefaultElementFunctions so it
// can be substituted wherever that struct is used as a template argument.
// Every stage is ck_tile::identity except PCompute and OAcc.
struct FmhaF8StaticQuantizationElementFunctions
{
using QElementFunction = ck_tile::identity;
using KElementFunction = ck_tile::identity;
using VElementFunction = ck_tile::identity;
using BiasElementFunction = ck_tile::identity;
using LSEElementFunction = ck_tile::identity;
using SAccElementFunction = ck_tile::identity;
// Unlike the other stages, PCompute uses ck_tile::scale.
using PComputeElementFunction = ck_tile::scale;
// OAcc combines saturate_f8 and scale through ck_tile::composer.
// NOTE(review): the order in which composer applies its arguments is not
// visible here -- confirm whether scale runs before saturate_f8.
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
{
@@ -316,8 +328,8 @@ struct fmha_fwd_args
float descale_sv;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args<FmhaDefaultElementFunctions> args)
template <typename FmhaKernel, typename FmhaFwdArgs>
auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
@@ -450,8 +462,8 @@ struct fmha_fwd_traits_
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args<FmhaDefaultElementFunctions>);
template <typename Traits_, typename FmhaFwdArgs_>
float fmha_fwd_(const ck_tile::stream_config&, FmhaFwdArgs_);
// This is the public API, will be generated by script
struct fmha_fwd_traits
@@ -466,6 +478,9 @@ struct fmha_fwd_traits
bool has_lse;
// TODO: padding check is inside this api
};
template<typename FmhaFwdArgs_>
float fmha_fwd(fmha_fwd_traits,
fmha_fwd_args<FmhaDefaultElementFunctions>,
FmhaFwdArgs_,
const ck_tile::stream_config&);

View File

@@ -46,6 +46,11 @@ PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
# Maps the generator's element-function key to the name of the C++ struct
# that is substituted for {F_element_func} in the generated kernel and API
# source. "no" selects the default element functions; "f8_static_quant"
# selects the FP8 static-quantization set.
ELEMENT_FUNC_MAP = {
"no" : "FmhaDefaultElementFunctions",
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
@@ -85,14 +90,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_mask_{F_idx} = {F_mask};
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
typename FmhaDefaultElementFunctions::QElementFunction,
typename FmhaDefaultElementFunctions::KElementFunction,
typename FmhaDefaultElementFunctions::VElementFunction,
typename FmhaDefaultElementFunctions::BiasElementFunction,
typename FmhaDefaultElementFunctions::LSEElementFunction,
typename FmhaDefaultElementFunctions::SAccElementFunction,
typename FmhaDefaultElementFunctions::PComputeElementFunction,
typename FmhaDefaultElementFunctions::OAccElementFunction>;
typename {F_element_func}::QElementFunction,
typename {F_element_func}::KElementFunction,
typename {F_element_func}::VElementFunction,
typename {F_element_func}::BiasElementFunction,
typename {F_element_func}::LSEElementFunction,
typename {F_element_func}::SAccElementFunction,
typename {F_element_func}::PComputeElementFunction,
typename {F_element_func}::OAccElementFunction>;
using fmha_fwd_args_{F_idx} = fmha_fwd_args<{F_element_func}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -129,7 +136,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<FmhaDefaultElementFunctions> a)
float fmha_fwd_<trait_{F_idx}, fmha_fwd_args_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args_{F_idx} a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -143,7 +150,9 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<Fm
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args<FmhaDefaultElementFunctions> a, const ck_tile::stream_config& s){{
using fmha_fwd_args_ = fmha_fwd_args<{F_element_func}>;
template<>
float fmha_fwd<fmha_fwd_args_>(fmha_fwd_traits t, fmha_fwd_args_ a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
@@ -183,7 +192,7 @@ class FmhaFwdApiTrait:
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0blen : int
bk0blen : int
vlayout : str
mask : str
bias : str # true/false
@@ -283,6 +292,7 @@ class FmhaFwdApiPool:
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
element_func='no'
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
@@ -299,7 +309,8 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP[element_func],
F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
@@ -324,44 +335,46 @@ class FmhaFwdTileSize:
@dataclass
class FmhaFwdKernel:
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
F_element_func : str
@property
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy ,
F_mask = MASK_MAP[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy ,
F_mask = MASK_MAP[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
@property
def name(self) -> str:
@@ -454,7 +467,15 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline)
element_func='no'
k = FmhaFwdKernel(direction=direction,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
F_element_func=element_func)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue

View File

@@ -49,6 +49,7 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/unary_element_function.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/random.hpp"