mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
1258 lines
54 KiB
Python
1258 lines
54 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
# generate kernel instances to speed up compilation
|
|
|
|
import argparse
|
|
import itertools
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import copy
|
|
import fnmatch
|
|
|
|
# Short dtype name -> ck_tile C++ element type used in generated sources.
# get_fwd_blobs() iterates these keys, so their order is the generation order.
DTYPE_MAP = dict(
    fp16="ck_tile::fp16_t",
    bf16="ck_tile::bf16_t",
    fp8="ck_tile::fp8_t",
)

# Storage width in bits of each dtype; FmhaFwdApiTrait.dcheck/dvcheck divide
# 128 (32 * 4) by this to get an element count per vector.
DTYPE_BITS = dict(
    fp32=32,
    fp16=16,
    bf16=16,
    fp8=8,
    bf8=8,
)

# Selectable mask implementations -> C++ mask class template.
MASK_IMPL = dict(
    generic="ck_tile::GenericAttentionMask",
    simplified="ck_tile::SimplifiedGenericAttentionMask",
)

# Mask variants used with the "simplified" implementation (see get_mask_map).
MASK_SIMPLIFIED_MAP = dict(
    s_no="ck_tile::SimplifiedGenericAttentionMask<false>",
    s_mask="ck_tile::SimplifiedGenericAttentionMask<true>",
)

# Mask variants used with the "generic" implementation (see get_mask_map).
# NOTE(review): the FmhaMasks aliases are declared on the C++ side, not here.
MASK_MAP = dict(
    no="FmhaMasks::NoMask",
    causal="FmhaMasks::CausalMask",
    generic="FmhaMasks::GenericMask",
)

# Bias variant -> C++ BlockAttentionBiasEnum value baked into kernel traits.
BIAS_MAP = dict(
    no="ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    bias="ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    alibi="ck_tile::BlockAttentionBiasEnum::ALIBI",
)

# Bias variant -> runtime enum compared against t.bias_type in the generated
# dispatchers.
# TODO: this is ugly (same keys as BIAS_MAP, different enum spelling)
BIAS_CHECK_MAP = dict(
    no="bias_enum::no_bias",
    bias="bias_enum::elementwise_bias",
    alibi="bias_enum::alibi",
)

# Batch/group mode -> C++ bool literal (true for group mode).
MODE_MAP = dict(batch="false", group="true")

# V layout -> C++ bool literal, compared against t.is_v_rowmajor.
LAYOUT_MAP = dict(row="true", col="false")

# Forward pipeline tag -> C++ pipeline class template.
PIPELINE_MAP = dict(
    qr="ck_tile::BlockFmhaPipelineQRKSVS",
    qr_async="ck_tile::BlockFmhaPipelineQRKSVSAsync",
)

# Forward pipeline tag -> C++ pipeline enum used in fmha_fwd_traits_<>.
PIPELINE_ENUM_MAP = dict(
    qr="ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    qr_async="ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
)

# 't'/'f' flag -> C++ bool literal.
BOOL_MAP = dict(t="true", f="false")

# Tile partitioner key -> C++ partitioner template
# (FmhaFwdKernel.get_tp() returns "hbs" for group mode, "shb" otherwise).
TILE_PARTITIONER_MAP = dict(
    shb="ck_tile::FmhaFwdTilePartitioner_SHB",
    hbs="ck_tile::FmhaFwdTilePartitioner_HBS",
)

GEN_DIR = ""  # in Cmake, have to generate files in same folder

# Preamble prepended to every generated forward .cpp (kernel instances and the
# dispatcher). The copyright line is followed by an empty line.
FMHA_FWD_KERNEL_HEADER = (
    "// SPDX-License-Identifier: MIT\n"
    "// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n"
    "\n"
    "// auto generated by generate.py\n"
    "#include \"fmha_fwd.hpp\"\n"
)

# str.format() template for one forward kernel instance (filled by
# FmhaFwdKernel.template). Aliases carry an {F_idx} suffix to keep symbol
# names distinct; doubled braces ({{ }}) become literal C++ braces.
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                      fmha_block_warps_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      fmha_block_warps_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      {F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_bias},
                                                    false,
                                                    {F_lse},
                                                    {F_dropout},
                                                    {F_squant},
                                                    {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_mask_{F_idx},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = {F_pipeline}<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                                           {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
                  fmha_pipeline_{F_idx},
                  fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
                        {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

# Name of the generated forward dispatcher translation unit.
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
# Skeleton of the generated fmha_fwd() entry point. {F_dispatch} receives the
# concatenated per-dtype branches; r stays -1 when no branch returns.
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

# One dtype branch; {F_if} is "if" for the first branch, "else if" afterwards.
FMHA_FWD_API_PER_DTYPE="""    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""
# One head-dim bucket inside a dtype branch (both hdim_q and hdim_v must fit).
FMHA_FWD_API_PER_HDIM_CASE="""        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
        }}
"""

# Generic mask variant -> C++ predicate over the runtime t.mask_type.
# Both top-left and bottom-right causal requests select the "causal" kernel.
MASK_CHECK_MAP = dict(
    no="t.mask_type == mask_enum::no_mask",
    causal="t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    generic="t.mask_type == mask_enum::window_generic",
)

# Simplified mask variant -> C++ predicate; any non-"no" mask selects s_mask.
MASK_SIMPLIFIED_CHECK_MAP = dict(
    s_no="t.mask_type == mask_enum::no_mask",
    s_mask="t.mask_type != mask_enum::no_mask",
)

# One kernel-selection branch of fmha_fwd(): the runtime traits/args must
# match this instance's properties, then the instance is launched and its
# result returned directly.
FMHA_FWD_API_INNER_DISPATCH="""            {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
                return fmha_fwd_<trait_>(s, a);
            }}
"""

def get_mask_map(mask : str):
    """Return the mask-variant -> C++ mask type table for a mask implementation.

    Args:
        mask: mask implementation name, "generic" or "simplified".

    Returns:
        MASK_MAP for "generic", MASK_SIMPLIFIED_MAP for "simplified".

    Raises:
        ValueError: for any other name. (Previously ``assert False`` was used,
            which is stripped under ``python -O`` and made this silently
            return None.)
    """
    if mask == "generic":
        return MASK_MAP
    if mask == "simplified":
        return MASK_SIMPLIFIED_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")

def get_mask_check_map(mask : str):
    """Return the mask-variant -> C++ runtime predicate table for a mask implementation.

    Args:
        mask: mask implementation name, "generic" or "simplified".

    Returns:
        MASK_CHECK_MAP for "generic", MASK_SIMPLIFIED_CHECK_MAP for "simplified".

    Raises:
        ValueError: for any other name (replaces ``assert False``, which
            vanished under ``python -O`` and returned None).
    """
    if mask == "generic":
        return MASK_CHECK_MAP
    if mask == "simplified":
        return MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")

@dataclass
class FmhaFwdApiTrait:
    """Flattened description of one forward kernel, used to emit one dispatch branch.

    FmhaFwdApiPool formats each trait into FMHA_FWD_API_INNER_DISPATCH. The
    ``scheck``/``skcheck``/``dcheck``/``dvcheck`` properties return C++
    boolean expressions over the runtime args ``a`` that gate whether this
    instance may serve a call.
    """
    pipeline_tag : str  # key of PIPELINE_MAP ('qr' / 'qr_async')
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim : str
    dtype : str  # key of DTYPE_MAP
    mode : str  # key of MODE_MAP ('batch' / 'group')
    bm0 : int  # tile size along q seqlen (block size)
    bn0 : int  # tile size along qk seqlen
    bk0 : int  # tile size along qk gemm unroll
    bn1 : int  # tile size along v head_dim
    bk1 : int  # tile size along kv gemm unroll
    bk0blen : int
    vlayout : str  # key of LAYOUT_MAP ('row' / 'col')
    mask : str  # key of the table returned by get_mask_map()
    bias : str  # key of BIAS_MAP
    lse : str  # 't'/'f'
    dropout : str  # 't'/'f'
    squant : str  # 't'/'f'
    spad : str  # 't'/'f': pad along seqlen_q
    skpad : str  # 't'/'f': pad along seqlen_k
    dpad : str  # 't'/'f': pad along hdim_q
    dvpad : str  # 't'/'f': pad along hdim_v

    @property
    def name(self) -> str:
        # Every tile dimension appears once; previously bn0 was emitted twice
        # and bn1 never, so traits differing only in bn1 got the same name.
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0blen}-' + \
            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    def _unsupported(self) -> ValueError:
        # Shared error for pipeline tags / pad combinations with no check rule.
        return ValueError(f'unsupported pipeline/pad combination for {self.pipeline_tag!r}')

    @property
    def scheck(self) -> str:
        """C++ guard on a.seqlen_q."""
        if self.mode == 'group':
            return 'true/*group mode spad always true*/'  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't':
                return 'true'  # always support
            else:
                return 'true'
        elif self.pipeline_tag == 'qr':
            if self.spad == 't':
                return f'true /*a.seqlen_q % {self.bm0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.seqlen_q % {self.bm0} == 0'
        raise self._unsupported()

    @property
    def skcheck(self) -> str:
        """C++ guard on a.seqlen_k."""
        if self.mode == 'group':
            return 'true/*group mode skpad always true*/'  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't':
                return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else:
                return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag == 'qr':  # was ['qr', 'qr_fp8']; 'qr_fp8' is not a known tag
            if self.skpad == 't':
                return f'true /*a.seqlen_k % {self.bn0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.seqlen_k % {self.bn0} == 0'
        raise self._unsupported()

    @property
    def dcheck(self) -> str:
        """C++ guard on a.hdim_q."""
        if self.pipeline_tag == 'qr_async':
            # number of elements of this dtype in 128 (32 * 4) bits
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't':
                return f'a.hdim_q % {vec} == 0'
            raise self._unsupported()
        elif self.pipeline_tag == 'qr':
            if self.dpad == 't':
                return f'true /*a.hdim_q % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.hdim_q % {self.bk0blen} == 0'
        raise self._unsupported()

    @property
    def dvcheck(self) -> str:
        """C++ guard on a.hdim_v."""
        if self.pipeline_tag == 'qr_async':
            # number of elements of this dtype in 128 (32 * 4) bits
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't':
                return f'a.hdim_v % {vec} == 0'
            raise self._unsupported()
        elif self.pipeline_tag == 'qr':
            if self.dvpad == 't':
                return f'true /*a.hdim_v % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f'a.hdim_v % {self.bk0blen} == 0'
        raise self._unsupported()

@dataclass
class FmhaFwdPipeline:
    """Pipeline-level configuration of a forward kernel; ``name`` is its filename fragment."""
    tag : str  # key of PIPELINE_MAP ('qr' / 'qr_async')

    F_vlayout : str  # 'row' / 'col'
    F_spad : str  # 't'/'f': pad seqlen_q
    F_skpad : str  # 't'/'f': pad seqlen_k
    F_dpad : str  # 't'/'f': pad hdim_q
    F_dvpad : str  # 't'/'f': pad hdim_v
    F_bias : str  # key of BIAS_MAP
    F_lse : str  # 't'/'f'
    F_dropout : str  # 't'/'f'
    F_squant : str  # 't'/'f'
    F_mask : str  # key of the table returned by get_mask_map()

    @property
    def name(self) -> str:
        # Padding flags collapse into one "p<dims>" token, e.g. "psddv".
        pad_flags = (
            (self.F_spad, 's'),
            (self.F_skpad, 'sk'),
            (self.F_dpad, 'd'),
            (self.F_dvpad, 'dv'),
        )
        padded = ''.join(label for flag, label in pad_flags if flag == 't')

        parts = [f'{self.tag}_v{self.F_vlayout[0]}']
        if padded:
            parts.append('p' + padded)
        if self.F_bias != 'no':
            parts.append(self.F_bias)
        if self.F_mask[0:2] == 's_':
            # simplified masks: only the masked variant is tagged
            if self.F_mask == 's_mask':
                parts.append('mask')
        elif self.F_mask != 'no':
            parts.append(f'm{self.F_mask[0]}')
        for flag, label in ((self.F_lse, 'lse'), (self.F_dropout, 'dropout'), (self.F_squant, 'squant')):
            if flag == 't':
                parts.append(label)
        return '_'.join(parts)

class FmhaFwdApiPool:
    """Collects FmhaFwdApiTrait entries and renders the fmha_fwd() dispatcher source."""

    def __init__(self, mask_impl):
        # dtype -> hdim -> [FmhaFwdApiTrait, ...] in registration order
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        """Store a shallow copy of ``trait`` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        by_hdim = self.pool.setdefault(trait.dtype, dict())
        by_hdim.setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """C++ source of fmha_fwd(): nested dtype -> hdim -> trait if/else-if chains."""
        def chain(pos):
            # first branch of each chain is "if", the rest "else if"
            return 'if' if pos == 0 else 'else if'

        dtype_cases = []
        for i, (dtype, by_hdim) in enumerate(self.pool.items()):
            hdim_cases = []
            for j, (hdim, traits) in enumerate(by_hdim.items()):
                inner_cases = []
                for k, trait in enumerate(traits):
                    inner_cases.append(FMHA_FWD_API_INNER_DISPATCH.format(
                        F_if=chain(k),
                        F_mode=MODE_MAP[trait.mode],
                        F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                        F_mask=get_mask_map(self.mask_impl)[trait.mask],
                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                        F_bias_check=BIAS_CHECK_MAP[trait.bias],
                        F_bias=BIAS_MAP[trait.bias],
                        F_lse=BOOL_MAP[trait.lse],
                        F_dropout=BOOL_MAP[trait.dropout],
                        F_squant=BOOL_MAP[trait.squant],
                        F_scheck=trait.scheck,
                        F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck,
                        F_dvcheck=trait.dvcheck,
                        F_spad=BOOL_MAP[trait.spad],
                        F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad],
                        F_dvpad=BOOL_MAP[trait.dvpad],
                        F_bm0=trait.bm0,
                        F_bn0=trait.bn0,
                        F_bk0=trait.bk0,
                        F_bn1=trait.bn1,
                        F_bk1=trait.bk1,
                        F_bk0blen=trait.bk0blen,
                        F_hdim=hdim,
                        F_dtype=DTYPE_MAP[dtype]))
                hdim_cases.append(FMHA_FWD_API_PER_HDIM_CASE.format(
                    F_if=chain(j), F_hdim=hdim, F_inner_dispatch=''.join(inner_cases)))
            dtype_cases.append(FMHA_FWD_API_PER_DTYPE.format(
                F_if=chain(i), F_dtype=dtype, F_hdim_case=''.join(hdim_cases)))
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=''.join(dtype_cases))

@dataclass
class FmhaFwdTileSize:
    """Block/warp tiling of one forward kernel; constructed positionally in this order."""
    F_bm0 : int  # tile size along q seqlen (block size)
    F_bn0 : int  # tile size along k seqlen
    F_bk0 : int  # tile size along qk gemm unroll
    F_bn1 : int  # tile size along v head_dim
    F_bk1 : int  # tile size along kv gemm unroll
    F_bk0blen : int  # total length of K0, used for pipelines that load Q at once (or repeatedly load Q as a whole tile)
    F_rm : int  # number of warps along q seqlen (block warps)
    F_rn : int  # number of warps along k seqlen (not used)
    F_rk : int  # number of warps along gemm-k (not used)
    F_wm : int  # warp tile size along m
    F_wn : int  # warp tile size along n
    F_wk : int  # warp tile size along k
    F_occupancy : int  # -1 lets the pipeline decide; any other value overrides occupancy

    @property
    def name(self) -> str:
        block_dims = (self.F_bm0, self.F_bn0, self.F_bk0, self.F_bn1, self.F_bk1, self.F_bk0blen)
        warp_counts = (self.F_rm, self.F_rn, self.F_rk)
        warp_dims = (self.F_wm, self.F_wn, self.F_wk)
        text = 'b' + 'x'.join(f'{v}' for v in block_dims)
        text += '_r' + 'x'.join(f'{v}' for v in warp_counts)
        text += '_w' + 'x'.join(f'{v}' for v in warp_dims)
        if self.F_occupancy != -1:
            text += f'_o{self.F_occupancy}'
        return text

@dataclass
class FmhaFwdKernel:
    """One forward kernel instance, rendered into its own .cpp file.

    Combines a tile configuration (F_tile) with a pipeline configuration
    (F_pipeline) for a given hdim/dtype/mode. ``template`` renders the C++
    source and ``api_trait`` returns the matching dispatcher entry.
    """
    direction : str
    F_idx : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim : int  # hdim
    F_dtype : str  # key of DTYPE_MAP
    F_mode : str  # key of MODE_MAP ('batch' / 'group')
    F_tile : FmhaFwdTileSize
    F_pipeline : FmhaFwdPipeline
    mask_impl : str  # "generic" / "simplified", forwarded to get_mask_map()

    def get_tp(self) -> str:
        """Key into TILE_PARTITIONER_MAP: 'hbs' in group mode, 'shb' otherwise."""
        if self.F_mode == 'group':
            return 'hbs'
        else:
            return 'shb'

    @property
    def template(self) -> str:
        """Complete C++ source: FMHA_FWD_KERNEL_HEADER + formatted FMHA_FWD_KERNEL_BODY."""
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_KERNEL_BODY.format(
                F_idx = self.F_idx,
                F_hdim = self.F_hdim,
                F_dtype = DTYPE_MAP[self.F_dtype],
                F_bm0 = self.F_tile.F_bm0,
                F_bn0 = self.F_tile.F_bn0,
                F_bk0 = self.F_tile.F_bk0,
                F_bn1 = self.F_tile.F_bn1,
                F_bk1 = self.F_tile.F_bk1,
                F_bk0blen = self.F_tile.F_bk0blen,
                F_rm = self.F_tile.F_rm,
                F_rn = self.F_tile.F_rn,
                F_rk = self.F_tile.F_rk,
                F_wm = self.F_tile.F_wm,
                F_wn = self.F_tile.F_wn,
                F_wk = self.F_tile.F_wk,
                F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias = BIAS_MAP[self.F_pipeline.F_bias],
                F_lse = BOOL_MAP[self.F_pipeline.F_lse],
                F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
                F_squant = BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode = MODE_MAP[self.F_mode],
                F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
                F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Dispatcher entry describing this instance (hdim is stringified)."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0blen=self.F_tile.F_bk0blen,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            bias=self.F_pipeline.F_bias,
            lse=self.F_pipeline.F_lse,
            dropout=self.F_pipeline.F_dropout,
            squant=self.F_pipeline.F_squant,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad)

# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Supported forward tile configurations keyed by hdim string.

    Returns None for any direction other than 'fwd' and for dtypes without
    a tile table (only fp16/bf16 and fp8/bf8 have one).
    """
    if direction != 'fwd':
        return None
    if dtype in ('fp16', 'bf16'):
        return {
            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
        }
    if dtype in ('fp8', 'bf8'):
        return {
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1),
        }
    return None

def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate the forward kernels to generate and register them with a dispatcher pool.

    Args:
        kernel_filter: optional fnmatch-style pattern; kernels whose ``name``
            does not match are skipped.
        receipt: generation recipe selector. ``1`` adds generic 'qr'
            pipelines next to 'qr_async' for non-256 hdims; ``2`` keeps only
            fp16/bf16, row-major V, no/alibi bias, non-squant kernels.
        mask_impl: mask implementation name passed to get_mask_map().

    Returns:
        (api_pool, kernels): the populated FmhaFwdApiPool and the
        FmhaFwdKernel list, both in generation order.
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list of possible pipelines
        # TODO: the order of List matters! the later in this list will also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))

                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                else:
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
                    if receipt == 1:
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse/dropout kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d is None:
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                k = FmhaFwdKernel(direction=direction,
                                  F_idx=0,
                                  F_hdim=hdim,
                                  F_dtype=dtype,
                                  F_mode=mode,
                                  F_tile=tile,
                                  F_pipeline=pipeline,
                                  mask_impl=mask_impl)
                if kernel_filter is not None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'alibi']
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)

# Backward dq/dk/dv pipeline tag -> C++ pipeline class template.
BWD_DQDKDV_PIPELINE_MAP = dict(
    ks_kts_vr="ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
    qs_ks_vr_dos="ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
    ks_vr="ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
)

# Backward dq/dk/dv pipeline tag -> C++ pipeline enum used in the traits.
BWD_DQDKDV_PIPELINE_ENUM_MAP = dict(
    ks_kts_vr="ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
    qs_ks_vr_dos="ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
    ks_vr="ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
)

# Preamble prepended to every generated backward .cpp (kernel instances and
# the dispatcher). The copyright line is followed by an empty line.
FMHA_BWD_KERNEL_HEADER = (
    "// SPDX-License-Identifier: MIT\n"
    "// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n"
    "\n"
    "// auto generated by generate.py\n"
    "#include \"fmha_bwd.hpp\"\n"
)

# str.format() template for one backward dq/dk/dv kernel instance. It defines
# the timed launcher, a one-shot launcher (used inside fmha_bwd_), and a
# name getter. Doubled braces ({{ }}) become literal C++ braces.
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
//       G0&G2 -> GSdP
//       G1&G3 -> GdKV
//       G4    -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
                                      fmha_block_warps0_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      fmha_block_warps1_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      fmha_block_warps0_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      fmha_block_warps1_{F_idx},
                                      fmha_warp_tile_{F_idx},
                                      fmha_block_warps2_{F_idx},
                                      fmha_warp_tile_{F_idx}>;

using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_bias},
                                                    {F_dbias},
                                                    false,
                                                    {F_dropout},
                                                    false,
                                                    {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};

using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
    fmha_bwd_shape_{F_idx},
    {F_mode},
    fmha_mask_{F_idx},
    fmha_bwd_trait_{F_idx}>;

using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
    fmha_bwd_pipeline_problem_{F_idx}>;

using fmha_bwd_dk_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
                               typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
                               false, false>>;

using fmha_bwd_dv_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
                               typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
                               false, false>>;

using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
    ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
                  fmha_bwd_pipeline_{F_idx},
                  fmha_bwd_dk_epilogue_{F_idx},
                  fmha_bwd_dv_epilogue_{F_idx}>;

using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}

template<>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
    return k_::GetName();
}}
"""

# Name of the generated backward dispatcher translation unit.
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
# Skeleton of the generated backward API. fmha_bwd_<> launches the dot(dO, O)
# kernel and then the dq/dk/dv kernel through one launch_kernel call;
# fmha_bwd() holds the dispatch tree ({F_dispatch}) and returns -1 on no match.
FMHA_BWD_API="""
#include <iostream>

template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    if(s.log_level_ > 0)
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
    return ck_tile::launch_kernel(s,
            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
            [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
    );
}}

float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

# One dtype branch of fmha_bwd(); {F_if} is "if" first, "else if" afterwards.
FMHA_BWD_API_PER_DTYPE="""    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""
# One head-dim bucket inside a dtype branch.
FMHA_BWD_API_PER_HDIM_CASE="""        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
        }}
"""

# One kernel-pair branch: F_spad0 pads the dq/dk/dv kernel, F_spad1 the
# dot(dO, O) kernel; the pair is launched via fmha_bwd_<> and its result returned.
FMHA_BWD_API_INNER_DISPATCH="""            {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
                r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
                return r;
            }}
"""

@dataclass
class FmhaBwdDQDKDVApiTrait:
    """Flattened description of one backward dq/dk/dv kernel for dispatch generation.

    FmhaBwdApiPool formats each trait into FMHA_BWD_API_INNER_DISPATCH. The
    check members return C++ boolean expressions over the runtime args ``a``.
    """
    pipeline : str  # key of BWD_DQDKDV_PIPELINE_ENUM_MAP
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim : str
    dtype : str  # key of DTYPE_MAP
    mode : str  # key of MODE_MAP ('batch' / 'group')
    bm0 : int  # tile size along q seqlen (block size)
    bn0 : int  # tile size along k seqlen
    bhdq : int  # q head_dim
    bhdv : int  # v head_dim
    mask : str
    bias : str
    dbias : str
    dropout : str
    spad : str
    skpad : str
    dpad : str
    dvpad : str

    @property
    def name(self) -> str:
        return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    def scheck(self, spad1 : str) -> str:
        """C++ seqlen_q guard for pairing this kernel (spad) with a dot(dO, O) kernel (spad1).

        The (spad 't', spad1 'f') pair falls into the final branch; FmhaBwdApiPool
        never requests it. 256 is presumably the dot(dO, O) kernel's seqlen block
        size (marked "BlockSize") — TODO confirm.
        """
        if self.mode == 'group':
            return 'true'  # always support
        elif self.spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {self.bm0} != 0'
        elif self.spad == 'f' and spad1 == 't':
            # '&&' instead of the alternative token 'and', matching every other
            # generated condition and compilers without iso646 keywords enabled
            return f'a.seqlen_q % {self.bm0} == 0 && a.seqlen_q % 256 != 0'  # BlockSize
        else:  # self.spad == 'f' and spad1 == 'f'
            return f'a.seqlen_q % 256 == 0'  # BlockSize

    @property
    def skcheck(self) -> str:
        """C++ guard on a.seqlen_k."""
        if self.mode == 'group':
            return 'true'  # always support
        elif self.skpad == 't':
            return f'a.seqlen_k % {self.bn0} != 0'
        else:
            return f'a.seqlen_k % {self.bn0} == 0'

    @property
    def dcheck(self) -> str:
        """C++ guard on a.hdim_q."""
        if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
        else : return f'a.hdim_q % {self.bhdq} == 0'

    @property
    def dvcheck(self) -> str:
        """C++ guard on a.hdim_v."""
        if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
        else : return f'a.hdim_v % {self.bhdv} == 0'

class FmhaBwdApiPool:
    """Collects dq/dk/dv kernel traits and renders the backward API dispatcher.

    Traits are bucketed as ``dq_dk_dv_pool[dtype][hdim] -> [trait, ...]`` in
    registration order; ``api`` turns the buckets into nested
    dtype -> hdim -> per-kernel branch chains.
    """

    def __init__(self, mask_impl):
        # dtype -> hdim -> list of registered traits
        self.dq_dk_dv_pool = dict()
        # mask implementation key, forwarded to get_mask_map()/get_mask_check_map()
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        """Store a shallow copy of ``trait`` under its dtype/hdim bucket."""
        # TODO: do we need to check duplication?
        per_hdim = self.dq_dk_dv_pool.setdefault(trait.dtype, dict())
        per_hdim.setdefault(trait.hdim, list()).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the backward API source: kernel header plus the nested dispatcher.

        Within one hdim bucket every emitted branch after the first uses
        'else if' (assuming FMHA_BWD_API_INNER_DISPATCH puts {F_if} in front of
        the branch condition), so the branches form a single chain.
        """
        per_dtypes = str()
        for i, (dtype, per_hdim) in enumerate(self.dq_dk_dv_pool.items()):
            per_hdim_case = str()
            for j, (hdim, traits) in enumerate(per_hdim.items()):
                inners = str()
                # Count emitted branches, not traits: one trait can expand into two
                # branches (spad1 = 't' and 'f'). Keying 'if' on the trait index made
                # the second branch of the first trait start a new 'if' chain.
                emitted = 0
                for trait in traits:
                    for spad1 in ["t", "f"]:
                        # An unpadded dot_do_o (spad1 == 'f') is only paired with an
                        # unpadded, batch-mode dq/dk/dv kernel.
                        if spad1 == "f" and (trait.spad == "t" or trait.mode == "group"):
                            continue
                        if_k = 'if' if emitted == 0 else 'else if'
                        emitted += 1
                        inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                                    F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
                                    F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
                                    F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])

                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)

        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
|
|
|
# GEMMs of the dq/dk/dv backward kernel, as referenced by the tile fields below:
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Block-tile, block-warp and warp-tile configuration of a dq/dk/dv kernel.

    ``name`` encodes all 22 fields, grouped as block tile (b...), the three
    block-warp layouts (r...), the warp tile (w...) and occupancy (o...).
    """
    F_bm0 : int # tile size along q seqlen (block size)
    F_bn0 : int # tile size along k seqlen
    F_bk0 : int # tile size along gemm0 unroll(F_bhdq)
    F_bk1 : int # tile size along gemm1 unroll(F_bm0)
    F_bk2 : int # tile size along gemm2 unroll(F_bhdv)
    F_bk3 : int # tile size along gemm3 unroll(F_bm0)
    F_bk4 : int # tile size along gemm4 unroll(F_bn0)
    F_bhdq : int # q head_dim
    F_bhdv : int # v head_dim
    F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
    F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
    F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
    F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
    F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
    F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
    F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
    F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
    F_rk2 : int # number of warps along gemm-k (not used) in gemm4
    F_wm : int # warp size along m (warp size)
    F_wn : int # warp size along n
    F_wk : int # warp size along k
    F_occupancy : int # occupancy

    @property
    def name(self) -> str:
        def dims(*values) -> str:
            # render a group of sizes as "AxBxC"
            return "x".join(str(value) for value in values)

        block_tile = dims(self.F_bm0, self.F_bn0, self.F_bk0, self.F_bk1, self.F_bk2,
                          self.F_bk3, self.F_bk4, self.F_bhdq, self.F_bhdv)
        warps_g02 = dims(self.F_rm0, self.F_rn0, self.F_rk0)
        warps_g13 = dims(self.F_rm1, self.F_rn1, self.F_rk1)
        warps_g4 = dims(self.F_rm2, self.F_rn2, self.F_rk2)
        warp_tile = dims(self.F_wm, self.F_wn, self.F_wk)
        return f"b{block_tile}_r{warps_g02}_r{warps_g13}_r{warps_g4}_w{warp_tile}_o{self.F_occupancy}"
|
|
|
|
@dataclass
class FmhaBwdDQDKDVKernel:
    """One generated dq/dk/dv backward kernel instance (rendered into one .cpp file).

    Flag-like string fields ('t'/'f') are turned into C++ booleans via BOOL_MAP
    when the template is rendered.
    """
    direction : str
    F_idx : int # this is not a tunable, but a counter to differentiate symbol
    F_hdim : int # hdim
    F_dtype : str # data type
    F_tile : FmhaBwdDQDKDVTileSize
    F_spad : str # true/false
    F_skpad : str #
    F_dpad : str #
    F_dvpad : str #
    F_bias : str #
    F_dbias : str #
    F_dropout : str #
    F_mask : str # value from MASK_MAP
    F_mode : str # value from MODE_MAP
    F_pipeline : str
    mask_impl : str

    @property
    def template(self) -> str:
        """Render the complete .cpp source: shared header followed by the kernel body."""
        tile = self.F_tile
        # Every tile/warp parameter is forwarded to the template under its own field name.
        tile_fields = ('F_bm0', 'F_bn0', 'F_bk0', 'F_bk1', 'F_bk2', 'F_bk3', 'F_bk4',
                       'F_bhdq', 'F_bhdv',
                       'F_rm0', 'F_rn0', 'F_rk0',
                       'F_rm1', 'F_rn1', 'F_rk1',
                       'F_rm2', 'F_rn2', 'F_rk2',
                       'F_wm', 'F_wn', 'F_wk',
                       'F_occupancy')
        tile_args = {field: getattr(tile, field) for field in tile_fields}
        body = FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=DTYPE_MAP[self.F_dtype],
            F_spad=BOOL_MAP[self.F_spad],
            F_skpad=BOOL_MAP[self.F_skpad],
            F_dpad=BOOL_MAP[self.F_dpad],
            F_dvpad=BOOL_MAP[self.F_dvpad],
            F_bias=BIAS_MAP[self.F_bias],
            F_dbias=BOOL_MAP[self.F_dbias],
            F_dropout=BOOL_MAP[self.F_dropout],
            F_mask=get_mask_map(self.mask_impl)[self.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
            F_pipeline=BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline],
            **tile_args)
        return FMHA_BWD_KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Kernel symbol/file stem: direction, hdim, dtype, mode and tile, plus
        optional suffixes for padding, bias, dbias, mask and dropout."""
        prefix = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name

        pad_tags = [tag for flag, tag in ((self.F_spad, 's'), (self.F_skpad, 'sk'),
                                          (self.F_dpad, 'd'), (self.F_dvpad, 'dv'))
                    if flag == 't']
        suffixes = []
        if pad_tags:
            suffixes.append('_p' + ''.join(pad_tags))
        if self.F_bias != 'no':
            suffixes.append(f'_{self.F_bias}')
        if self.F_dbias == 't':
            suffixes.append('_dbias')
        if self.F_mask.startswith('s_'):
            # simplified masks: only the masked variant gets a suffix
            if self.F_mask == 's_mask':
                suffixes.append('_mask')
        elif self.F_mask != 'no':
            # generic masks: first letter of the mask kind
            suffixes.append(f'_m{self.F_mask[0]}')
        if self.F_dropout == 't':
            suffixes.append('_dropout')
        return prefix + ''.join(suffixes)

    @property
    def filename(self) -> str:
        return f"{self.name}.cpp"

    def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
        """Build the dispatcher-side trait describing this kernel."""
        tile = self.F_tile
        return FmhaBwdDQDKDVApiTrait(
            pipeline=self.F_pipeline, hdim=str(self.F_hdim), dtype=self.F_dtype, mode=self.F_mode,
            bm0=tile.F_bm0, bn0=tile.F_bn0, bhdq=tile.F_bhdq, bhdv=tile.F_bhdv,
            mask=self.F_mask, bias=self.F_bias, dbias=self.F_dbias, dropout=self.F_dropout,
            spad=self.F_spad, skpad=self.F_skpad, dpad=self.F_dpad, dvpad=self.F_dvpad)
|
|
|
|
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return the supported dq/dk/dv tile/pipeline table, or None.

    The table maps an hdim string ('32', '64', '128') to a
    ``[FmhaBwdDQDKDVTileSize, pipeline_name]`` pair. Only ``direction == 'bwd'``
    with dtype 'fp16' or 'bf16' is supported.
    """
    supported = direction == 'bwd' and dtype in ('fp16', 'bf16')
    if not supported:
        return None
    # hdim -> [tile configuration, pipeline name]
    return {
        '32'  : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32,  32,  32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
                 "qs_ks_vr_dos"],
        '64'  : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32,  64,  64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
                 "qs_ks_vr_dos"],
        '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
                 "ks_vr"],
    }
|
|
|
|
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    """Enumerate every dq/dk/dv backward kernel to generate.

    Args:
        kernel_filter: optional fnmatch pattern; kernels whose name does not
            match are skipped.
        receipt: codegen receipt; 2 keeps only fp16/bf16 kernels with 'no' or
            'alibi' bias.
        mask_impl: mask implementation key forwarded to get_mask_map().

    Returns:
        (api_pool, kernels): the dispatcher pool with the trait of every kept
        kernel registered, and the kept kernels in generation order.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad
    # support this in future
    gen = list()
    api_pool = FmhaBwdApiPool(mask_impl)
    flags = ["t", "f"]  # loop-invariant values for every boolean knob

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d is None:
            continue
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(
                d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(),
                flags, flags, flags, flags, flags, flags):
            tile = d[hdim_str][0]
            ppl = d[hdim_str][1]
            hdim = int(hdim_str)
            # group mode is only generated with seqlen padding on both q and k
            if mode == "group" and (spad == "f" or skpad == "f"):
                continue
            # a bias gradient only exists for an elementwise (non-alibi) bias
            if (bias == "no" or bias == "alibi") and dbias == "t":
                continue
            k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                                    F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                                    F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
                                    F_pipeline=ppl, mask_impl=mask_impl)
            if kernel_filter is not None and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            if receipt == 2:
                cond = dtype in ['fp16', 'bf16']
                cond &= bias in ['no', 'alibi']
                if not cond:
                    continue
            api_pool.register_dq_dk_dv_traits(k.api_trait())
            gen.append(k)

    return (api_pool, gen)
|
|
|
|
# C++ source template for one OGrad-dot-O ("dot_do_o") backward kernel instance.
# Filled via str.format() with F_idx, F_dtype, F_spad, F_dvpad, F_occupancy,
# F_hdim and F_mode; literal C++ braces are therefore doubled ({{ }}).
# Stray '|' lines previously embedded in this text are removed: they would be
# copied verbatim into every generated .cpp and break C++ compilation.
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
                                                                            {F_dvpad},
                                                                            {F_occupancy}>;

using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    /* BlockSize = */ 256,
    {F_hdim},
    {F_mode},
    fmha_bwd_dot_do_o_trait_{F_idx}>;

using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
    fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;

using fmha_bwd_dot_do_o_kernel_{F_idx} =
    ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
                                    fmha_bwd_dot_do_o_{F_idx}>;

using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template<>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}

template<>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    return k_::GetName();
}}
"""
|
|
|
|
@dataclass
class FmhaBwdOGradDotOKernel:
    """One generated OGrad-dot-O backward kernel instance (rendered into one .cpp file)."""
    direction : str
    F_idx : int # this is not a tunable, but a counter to differentiate symbol
    F_hdim : int # hdim
    F_dtype : str # data type
    F_spad : str # true/false
    F_dvpad : str #
    F_mode : str # value from MODE_MAP
    F_occupancy : int

    @property
    def template(self) -> str:
        """Render the complete .cpp source: shared header followed by the kernel body."""
        fields = dict(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=DTYPE_MAP[self.F_dtype],
            F_spad=BOOL_MAP[self.F_spad],
            F_dvpad=BOOL_MAP[self.F_dvpad],
            F_mode=MODE_MAP[self.F_mode],
            F_occupancy=self.F_occupancy,
        )
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(**fields)

    @property
    def name(self) -> str:
        """Kernel stem: direction, hdim, dtype, mode, occupancy and an optional
        '_p' padding suffix ('s' for seqlen, 'dv' for v head_dim)."""
        pad = ('s' if self.F_spad == 't' else '') + ('dv' if self.F_dvpad == 't' else '')
        base = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
        return f"{base}_p{pad}" if pad else base

    @property
    def filename(self) -> str:
        return f"{self.name}.cpp"
|
|
|
|
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate every OGrad-dot-O backward kernel to generate.

    Emits one kernel per (dtype, hdim, mode, spad, dvpad) for each dtype/hdim
    that has a dq/dk/dv tile table. Group mode is only generated with
    spad == 't'. Note: unlike get_bwd_dq_dk_dv_blobs, no kernel filter or
    receipt is applied here.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    # support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = list()

    for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
            hdim = int(hdim_str)
            if mode == "group" and spad == "f":
                continue
            k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
                                       F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                                       F_occupancy=get_occupancy(dtype, hdim))
            gen.append(k)

    return gen
|
|
|
|
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write one forward kernel's rendered source to ``autogen_dir/kernel.filename``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
|
|
|
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the forward dispatcher source (``api_pool.api``) to FMHA_FWD_API_FILENAME."""
    api_path = autogen_dir / FMHA_FWD_API_FILENAME
    api_path.write_text(api_pool.api)
|
|
|
|
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write one dq/dk/dv backward kernel's rendered source to ``autogen_dir/kernel.filename``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
|
|
|
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write one OGrad-dot-O backward kernel's rendered source to ``autogen_dir/kernel.filename``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
|
|
|
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write the backward dispatcher source (``api_pool.api``) to FMHA_BWD_API_FILENAME."""
    api_path = autogen_dir / FMHA_BWD_API_FILENAME
    api_path.write_text(api_pool.api)
|
|
|
|
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate every source file for ``direction`` and write them to disk.

    Args:
        output_dir: destination root; files go to ``output_dir/GEN_DIR``. When
            None, files go directly into this script's directory.
            NOTE(review): GEN_DIR is not appended in that case, unlike
            list_blobs, which always lists paths under GEN_DIR.
        direction: 'fwd' for forward; any other value generates backward kernels.
        kernel_filter, receipt, mask_impl: forwarded to the blob enumerators.
    """
    target_dir = Path(__file__).parent if output_dir is None else Path(output_dir) / GEN_DIR
    target_dir.mkdir(parents=True, exist_ok=True)

    if direction == 'fwd':
        api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            write_single_fwd_kernel(kernel, target_dir)
        write_fwd_api(api_pool, target_dir)
        return

    # backward: the OGrad-dot-O kernels first, then dq/dk/dv kernels and their API
    for dot_kernel in get_bwd_dot_do_o_blobs():
        write_single_bwd_dot_do_o_kernel(dot_kernel, target_dir)
    api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
    for kernel in kernels:
        write_single_bwd_dq_dk_dv_kernel(kernel, target_dir)
    write_bwd_api(api_pool, target_dir)
|
|
|
|
# list all the files that will be generated
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Append the path of every file that would be generated to ``output_file``.

    Paths are rooted at ``<directory of output_file>/GEN_DIR``, one per line;
    'fwd' lists forward files, any other direction lists backward files.
    """
    assert output_file is not None
    file_path = Path(output_file)
    with file_path.open('a') as f:
        gen_root = file_path.parent / GEN_DIR

        def emit(filename):
            f.write(str(gen_root / filename) + "\n")

        if direction == 'fwd':
            _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                emit(kernel.filename)
            emit(FMHA_FWD_API_FILENAME)
        else:
            for dot_kernel in get_bwd_dot_do_o_blobs():
                emit(dot_kernel.filename)
            _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
            for kernel in kernels:
                emit(kernel.filename)
            emit(FMHA_BWD_API_FILENAME)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: either list the files that would be generated
    # (-l) or generate them (-o), for the chosen direction.
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen api for CK fmha kernel",
    )
    parser.add_argument(
        "-d",
        "--direction",
        default='fwd',
        choices=['fwd', 'bwd'],
        required=False,
        help="choose the direction of kernels(default: fwd)"
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        # Parse as int here so a non-numeric value produces an argparse usage
        # error instead of a ValueError traceback from int() below.
        type=int,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
             " 1: generate more instance to cover all hdim\n" + \
             " 2: Only generate instance for Flash attention integration"
    )

    args = parser.parse_args()
    if args.list_blobs is not None:
        list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
|