diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ee4506591e..358c3d9098 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -430,7 +430,7 @@ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); template -float fmha_splitkv_(const ck_tile::stream_config&, fmha_fwd_args); +float fmha_fwd_splitkv_(const ck_tile::stream_config&, fmha_fwd_args); // This is the public API, will be generated by script struct fmha_fwd_traits diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index e0b4b65595..d798c7d6cd 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -5,7 +5,7 @@ import argparse import itertools from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from dataclasses import dataclass import copy import fnmatch @@ -196,10 +196,12 @@ MASK_SIMPLIFIED_CHECK_MAP = { FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_(s, a); + return {F_callee}(s, a); }} """ +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" + def get_mask_map(mask : str): if mask == "generic": return MASK_MAP @@ -329,10 +331,49 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' return n +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self, mask_impl, is_splitkv): self.pool = dict() self.mask_impl = mask_impl + self.is_splitkv = is_splitkv def register_traits(self, trait : FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? @@ -345,6 +386,7 @@ class FmhaFwdApiPool: @property def api(self) -> str: + callee = "fmha_fwd_splitkv_" if self.is_splitkv else "fmha_fwd_" per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() @@ -360,7 +402,8 @@ class FmhaFwdApiPool: F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_callee=callee) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -474,6 +517,92 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad) +@dataclass +class FmhaFwdSplitKVKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + # TODO: design a more practical way to do it # this is current supported tile size per hdim def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: @@ -496,7 +625,10 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, is_splitkv : bool = False) -> Tuple[FmhaFwdApiPool, List[Union[FmhaFwdKernel, FmhaFwdSplitKVKernel]]]: + Pipeline = FmhaFwdSplitKVPipeline if is_splitkv else FmhaFwdPipeline + Kernel = FmhaFwdSplitKVKernel if is_splitkv else FmhaFwdKernel + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -510,29 +642,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) else: assert False return pipelines gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool(mask_impl, is_splitkv) for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) @@ -547,7 +679,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - k = FmhaFwdKernel(direction=direction, + k = Kernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -1144,11 +1276,12 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: return gen -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: +def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel], autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + file_path = autogen_dir / (FMHA_FWD_SPLITKV_API_FILENAME if api_pool.is_splitkv else FMHA_FWD_API_FILENAME) + file_path.write_text(api_pool.api) def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) @@ -1171,6 +1304,12 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) + + # write split-kv blobs + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) else: kernels = get_bwd_dot_do_o_blobs() for kernel in kernels: @@ -1190,6 +1329,12 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + + # get split-kv blobs + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") else: kernels = get_bwd_dot_do_o_blobs() for kernel in kernels: