mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Let generate.py generate different elementwise functions
This commit is contained in:
@@ -49,7 +49,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float; // TODO: fix me
|
||||
using BiasDataType = float;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
@@ -70,6 +70,18 @@ struct FmhaDefaultElementFunctions
|
||||
using OAccElementFunction = ck_tile::identity;
|
||||
};
|
||||
|
||||
// Element-function set for the FP8 static-quantization path.
// Every alias is ck_tile::identity except:
//   - PComputeElementFunction, which is ck_tile::scale
//   - OAccElementFunction, which is ck_tile::composer<saturate_f8, scale>
// NOTE(review): the order in which ck_tile::composer applies its two
// arguments is not visible here -- confirm against ck_tile before relying on it.
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using QElementFunction = ck_tile::identity;
|
||||
using KElementFunction = ck_tile::identity;
|
||||
using VElementFunction = ck_tile::identity;
|
||||
using BiasElementFunction = ck_tile::identity;
|
||||
using LSEElementFunction = ck_tile::identity;
|
||||
using SAccElementFunction = ck_tile::identity;
|
||||
using PComputeElementFunction = ck_tile::scale;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
@@ -316,8 +328,8 @@ struct fmha_fwd_args
|
||||
float descale_sv;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args<FmhaDefaultElementFunctions> args)
|
||||
template <typename FmhaKernel, typename FmhaFwdArgs>
|
||||
auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
@@ -450,8 +462,8 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args<FmhaDefaultElementFunctions>);
|
||||
template <typename Traits_, typename FmhaFwdArgs_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, FmhaFwdArgs_);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
@@ -466,6 +478,9 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
template<typename FmhaFwdArgs_>
|
||||
float fmha_fwd(fmha_fwd_traits,
|
||||
fmha_fwd_args<FmhaDefaultElementFunctions>,
|
||||
FmhaFwdArgs_,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
|
||||
@@ -46,6 +46,11 @@ PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
# Maps a short element-function tag to the name of the C++ struct that is
# substituted for the {F_element_func} placeholder when the kernel and API
# source templates are formatted.
ELEMENT_FUNC_MAP = {
|
||||
"no" : "FmhaDefaultElementFunctions",
|
||||
"f8_static_quant" : "FmhaF8StaticQuantizationElementFunctions",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false"
|
||||
@@ -85,14 +90,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_element_function_{F_idx} = ck_tile::FmhaElementFunctions<
|
||||
typename FmhaDefaultElementFunctions::QElementFunction,
|
||||
typename FmhaDefaultElementFunctions::KElementFunction,
|
||||
typename FmhaDefaultElementFunctions::VElementFunction,
|
||||
typename FmhaDefaultElementFunctions::BiasElementFunction,
|
||||
typename FmhaDefaultElementFunctions::LSEElementFunction,
|
||||
typename FmhaDefaultElementFunctions::SAccElementFunction,
|
||||
typename FmhaDefaultElementFunctions::PComputeElementFunction,
|
||||
typename FmhaDefaultElementFunctions::OAccElementFunction>;
|
||||
typename {F_element_func}::QElementFunction,
|
||||
typename {F_element_func}::KElementFunction,
|
||||
typename {F_element_func}::VElementFunction,
|
||||
typename {F_element_func}::BiasElementFunction,
|
||||
typename {F_element_func}::LSEElementFunction,
|
||||
typename {F_element_func}::SAccElementFunction,
|
||||
typename {F_element_func}::PComputeElementFunction,
|
||||
typename {F_element_func}::OAccElementFunction>;
|
||||
|
||||
using fmha_fwd_args_{F_idx} = fmha_fwd_args<{F_element_func}>;
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
@@ -129,7 +136,7 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<FmhaDefaultElementFunctions> a)
|
||||
float fmha_fwd_<trait_{F_idx}, fmha_fwd_args_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args_{F_idx} a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
@@ -143,7 +150,9 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args<Fm
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args<FmhaDefaultElementFunctions> a, const ck_tile::stream_config& s){{
|
||||
using fmha_fwd_args_ = fmha_fwd_args<{F_element_func}>;
|
||||
template<>
|
||||
float fmha_fwd<fmha_fwd_args_>(fmha_fwd_traits t, fmha_fwd_args_ a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
@@ -183,7 +192,7 @@ class FmhaFwdApiTrait:
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0blen : int
|
||||
bk0blen : int
|
||||
vlayout : str
|
||||
mask : str
|
||||
bias : str # true/false
|
||||
@@ -283,6 +292,7 @@ class FmhaFwdApiPool:
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
element_func='no'
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
@@ -299,7 +309,8 @@ class FmhaFwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP[element_func],
|
||||
F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
@@ -324,44 +335,46 @@ class FmhaFwdTileSize:
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
direction : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
direction : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
F_element_func : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_occupancy = self.F_tile.F_occupancy ,
|
||||
F_mask = MASK_MAP[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_occupancy = self.F_tile.F_occupancy ,
|
||||
F_mask = MASK_MAP[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -454,7 +467,15 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline)
|
||||
element_func='no'
|
||||
k = FmhaFwdKernel(direction=direction,
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
F_element_func=element_func)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user