Generate fmha_fwd_splitkv()

This commit is contained in:
PoYen, Chen
2024-06-02 13:20:31 +00:00
parent c928fefaae
commit 78b64d11c4
2 changed files with 166 additions and 21 deletions

View File

@@ -430,7 +430,7 @@ template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
float fmha_splitkv_(const ck_tile::stream_config&, fmha_fwd_args);
float fmha_fwd_splitkv_(const ck_tile::stream_config&, fmha_fwd_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits

View File

@@ -5,7 +5,7 @@
import argparse
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
import copy
import fnmatch
@@ -196,10 +196,12 @@ MASK_SIMPLIFIED_CHECK_MAP = {
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
return {F_callee}<trait_>(s, a);
}}
"""
FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
def get_mask_map(mask : str):
if mask == "generic":
return MASK_MAP
@@ -329,10 +331,49 @@ class FmhaFwdPipeline:
if self.F_squant == 't' : n += '_squant'
return n
@dataclass
class FmhaFwdSplitKVPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout'
if self.F_squant == 't' : n += '_squant'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
def __init__(self, mask_impl, is_splitkv):
self.pool = dict()
self.mask_impl = mask_impl
self.is_splitkv = is_splitkv
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
@@ -345,6 +386,7 @@ class FmhaFwdApiPool:
@property
def api(self) -> str:
callee = "fmha_fwd_splitkv_" if self.is_splitkv else "fmha_fwd_"
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
@@ -360,7 +402,8 @@ class FmhaFwdApiPool:
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_callee=callee)
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -474,6 +517,92 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
@dataclass
class FmhaFwdSplitKVKernel:
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
def get_tp(self) -> str:
if self.F_mode == 'group':
return 'hbs'
else:
return 'shb'
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_{self.direction}_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
@@ -496,7 +625,10 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, is_splitkv : bool = False) -> Tuple[FmhaFwdApiPool, List[Union[FmhaFwdKernel, FmhaFwdSplitKVKernel]]]:
Pipeline = FmhaFwdSplitKVPipeline if is_splitkv else FmhaFwdPipeline
Kernel = FmhaFwdSplitKVKernel if is_splitkv else FmhaFwdKernel
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -510,29 +642,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
api_pool = FmhaFwdApiPool(mask_impl, is_splitkv)
for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
@@ -547,7 +679,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
k = FmhaFwdKernel(direction=direction,
k = Kernel(direction=direction,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -1144,11 +1276,12 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel], autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
file_path = autogen_dir / (FMHA_FWD_SPLITKV_API_FILENAME if api_pool.is_splitkv else FMHA_FWD_API_FILENAME)
file_path.write_text(api_pool.api)
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
@@ -1171,6 +1304,12 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
# write split-kv blobs
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
@@ -1190,6 +1329,12 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
# get split-kv blobs
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels: