mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
* enable gfx940
* switch between intrinsic mfma routines on mi100/200 and mi300
* fix mfma_int8 on MI300
* disable 2 int8 examples on MI300
* Update cmake-ck-dev.sh
* restore gitignore file
* modify Jenkinsfile to the internal repo
* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx
Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)
---
updated-dependencies:
- dependency-name: rocm-docs-core
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* initial enablement of gfx950
* fix clang format
* disable examples 31 and 41 int8 on gfx950
* add code
* fix build wip
* fix xx
* now can build
* naming
* minor fix
* wip fix
* fix macro for exp2; fix warpgemm a/b in transposedC
* unify as tuple_array
* Update the required Python version to 3.9
* Update executable name in test scripts
* re-structure tuple/array to avoid spill
* Merge function templates
* Fix format
* Add constraint to array<> ctor
* Re-use function
* Some minor changes
* remove wrong code in store_raw()
* fix compile issue in transpose
* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'
* let more integral_constant->constant, and formating
* make sure thread_buffer can be tuple/array
* temp fix buffer_store spill
* not using custom data type by default, now we can have ISA-level same code as opt_padding
* fix compile error, fp8 not ready now
* fix fp8 duplicated move/shift/and/or problem
* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode
* fix scratch in fp8 kernel
* update some readme
* fix merge from upstream
* sync with upstream
* sync upstream again
* sync 22
* remove unused
* fix clang-format
* update README of ck_tile example
* fix several issue
* let python version to be 3.8 as minimal
* remove ck_tile example from default cmake target like all/install/check
* remove mistake
* 1) support recipe in generate.py 2) use simplified mask type 3) change left/right to pass into karg
* fix some bug in group-mode masking and codegen. update README
* F8 quantization for FMHA forward (#1224)
* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline
* Add element function to fmha api
* Adjust P elementwise function
* Fix bug of elementwise op, our elementwise op is not inout
* Add some elementwise op, prepare to quantization
* Let generate.py can generate different elementwise function
* To prevent compiler issue, remove the elementwise function we have not used.
* Remove f8 pipeline, we should share the same pipeline even in f8
* Remove remove_cvref_t
* Avoid warning
* Fix wrong fp8 QK/KV block gemm setting
* Check fp8 rounding error in check_err()
* Set fp8 rounding error for check_err()
* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode
* 1. codgen the f8 api and kernel
2. f8 host code
* prevent warning in filter mode
* Remove not-in-use elementwise function kargs
* Remove more not-in-use elementwise function kargs
* Small refinements in C++ source files
* Use conditional_t<> to simplify code
* Support heterogeneous argument for binary function types
* Re-use already-existing scales<> functor template
* Fix wrong value produced by saturating
* Generalize the composes<> template
* Unify saturates<> implementation
* Fix type errors in composes<>
* Extend less_equal<>
* Reuse the existing template less_equal<> in check_err()
* Add equal<float> & equal<double>
* Rename check_err() parameter
* Rename check_err() parameter
* Add FIXME comment for adding new macro in future
* Remove unnecessary cast to void
* Eliminate duplicated code
* Avoid dividing api pool into more than 2 groups
* Use more clear variable names
* Use affirmative condition in if stmt
* Remove blank lines
* Do not use perfect forwarding in composes<>
* To fix compile error, revert generate.py back to 4439cc107d
* Fix bug of p element function
* Add compute element op to host softmax
* Remove element function in api interface
* Extract user parameter
* Rename pscale and oscale variable
* rename f8 to fp8
* rename more f8 to fp8
* Add pipeline::operator() without element_functor
* 1. Remove deprecated pipeline enum
2. Refine host code parameter
* Use quantization range as input
* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s
3.Add init description
* Refine description
* prevent early return
* unify _squant kernel name in cpp, update README
* Adjust the default range.
* Refine error message and bias range
* Add fp8 benchmark and smoke test
* fix fp8 swizzle_factor=4 case
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
606 lines
25 KiB
Python
606 lines
25 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
# generate kernel instances to speed up compilation
|
|
|
|
import argparse
|
|
import itertools
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import copy
|
|
import fnmatch
|
|
|
|
# dtype command-line token -> ck_tile C++ element type
DTYPE_MAP = {
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp8" : "ck_tile::fp8_t"
}

# dtype token -> element width in bits (used to derive load/store vector sizes)
DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

# mask implementation flavor -> C++ attention-mask class template
# NOTE(review): not referenced elsewhere in this file — presumably kept for external use; confirm before removing
MASK_IMPL = {
    "generic" : "ck_tile::GenericAttentionMask",
    "simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}

# simplified-mask tag -> concrete C++ mask instantiation
MASK_SIMPLIFIED_MAP = {
    "s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}

# generic-mask tag -> C++ mask type alias
MASK_MAP = {
    "no" : "FmhaMasks::NoMask",
    "causal" : "FmhaMasks::CausalMask",
    "generic" : "FmhaMasks::GenericMask"
}

# batch/group mode -> C++ bool literal for the group-mode template argument
MODE_MAP = {
    "batch" : "false",
    "group" : "true"
}

# V-tensor layout -> C++ bool literal for the row-major template argument
LAYOUT_MAP = {
    "row" : "true",
    "col" : "false"
}

# pipeline tag -> C++ pipeline class
PIPELINE_MAP = {
    "qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}

# pipeline tag -> C++ pipeline enum value (used inside fmha_fwd_traits_<>)
PIPELINE_ENUM_MAP = {
    "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

# 't'/'f' shorthand -> C++ bool literal
BOOL_MAP = {
    "t" : "true",
    "f" : "false"
}

# only forward kernels are generated by this script
DIRECTIONS = ["fwd"]
GEN_DIR = ""    # in Cmake, have to generate files in same folder
|
|
|
|
# common header emitted at the top of every generated .cpp file
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_fwd.hpp"
"""

# per-kernel translation unit: type aliases for the tile/pipeline/epilogue plus
# the explicit specialization of fmha_fwd_<> that launches the kernel.
# The {F_*} placeholders are filled by FmhaFwdKernel.template.
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
{F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_lse},
{F_squant},
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

#include <iostream>

template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
}}
"""

# runtime dispatch entry point: body is assembled by FmhaFwdApiPool.api
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""

# one if/else-if branch per data type inside fmha_fwd()
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# one if/else-if branch per head-dim bucket inside a dtype branch
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
# generic-mask tag -> C++ runtime condition selecting that mask
MASK_CHECK_MAP = {
    "no" : "t.mask_type == mask_enum::no_mask",
    "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic" : "t.mask_type == mask_enum::window_generic",
}

# simplified-mask tag -> C++ runtime condition selecting that mask
MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no" : "t.mask_type == mask_enum::no_mask",
    "s_mask" : "t.mask_type != mask_enum::no_mask",
}

# innermost dispatch: all trait predicates must hold to pick this instance
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
|
|
|
|
def get_mask_map(mask : str):
    """Return the tag -> C++ mask-type table for the chosen mask implementation.

    :param mask: mask implementation flavor, "generic" or "simplified"
    :raises ValueError: for any other flavor
    """
    if mask == "generic":
        return MASK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_MAP
    # raise instead of `assert False`: asserts are stripped under `python -O`,
    # which previously let this fall through to an implicit None
    raise ValueError(f"unsupported mask implementation: {mask!r}")
|
|
|
|
def get_mask_check_map(mask : str):
    """Return the tag -> C++ runtime-check table for the chosen mask implementation.

    :param mask: mask implementation flavor, "generic" or "simplified"
    :raises ValueError: for any other flavor
    """
    if mask == "generic":
        return MASK_CHECK_MAP
    elif mask == "simplified":
        return MASK_SIMPLIFIED_CHECK_MAP
    # raise instead of `assert False`: asserts are stripped under `python -O`,
    # which previously let this fall through to an implicit None
    raise ValueError(f"unsupported mask implementation: {mask!r}")
|
|
|
|
@dataclass
class FmhaFwdApiTrait:
    """One kernel instantiation's traits, mirrored into the runtime dispatch table.

    Must stay in sync with the C++ `fmha_fwd_traits_<>` template. The *check
    properties return the C++ boolean expressions guarding each dispatch branch.
    """
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim : str
    dtype : str    # data type
    mode : str     # value from MODE_MAP
    bm0 : int      # tile size along q seqlen (block size)
    bn0 : int      # tile size along qk seqlen
    bk0 : int      # tile size along qk gemm unroll
    bn1 : int      # tile size along v head_dim
    bk1 : int      # tile size along kv gemm unroll
    bk0blen : int
    vlayout : str
    mask : str
    bias : str     # true/false
    lse : str
    squant : str
    spad : str
    skpad : str
    dpad : str
    dvpad : str

    @property
    def name(self) -> str:
        """Unique key for this trait combination.

        BUGFIX: the 4th tile field previously repeated ``bn0`` instead of
        emitting ``bn1``, so traits differing only in bn1 produced identical names.
        """
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0blen}-'+\
               f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

    @property
    def scheck(self) -> str:
        """C++ predicate: does the runtime seqlen_q fit this kernel's q-padding choice?"""
        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't' : return 'true' # always support
            else : return 'true'
        elif self.pipeline_tag in ['qr']:
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else : return f'a.seqlen_q % {self.bm0} == 0'
        else: assert False

    @property
    def skcheck(self) -> str:
        """C++ predicate: does the runtime seqlen_k fit this kernel's k-padding choice?"""
        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
            else : return f'a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qr_fp8']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else : return f'a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property
    def dcheck(self) -> str:
        """C++ predicate: does the runtime hdim_q fit this kernel's d-padding choice?"""
        if self.pipeline_tag == 'qr_async':
            # async pipeline requires hdim to be a multiple of the dword-vector width
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else : assert False
        elif self.pipeline_tag in ['qr']:
            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else : return f'a.hdim_q % {self.bk0blen} == 0'
        else: assert False

    @property
    def dvcheck(self) -> str:
        """C++ predicate: does the runtime hdim_v fit this kernel's dv-padding choice?"""
        if self.pipeline_tag == 'qr_async':
            # async pipeline requires hdim to be a multiple of the dword-vector width
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else : assert False
        elif self.pipeline_tag in ['qr']:
            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else : return f'a.hdim_v % {self.bk0blen} == 0'
        else: assert False
|
|
|
|
@dataclass
class FmhaFwdPipeline:
    """A pipeline variant together with the padding/feature knobs it is built with."""
    tag : str

    F_vlayout : str   # row/col
    F_spad : str      # true/false
    F_skpad : str
    F_dpad : str
    F_dvpad : str
    F_bias : str      # true/false
    F_lse : str
    F_squant : str
    F_mask : str      # value from MASK_MAP

    @property
    def name(self) -> str:
        """Short mnemonic encoding every enabled feature of this pipeline."""
        # collapse the four padding flags into a single 'p<flags>' token
        pad_flags = ''.join(token for flag, token in
                            ((self.F_spad, 's'), (self.F_skpad, 'sk'),
                             (self.F_dpad, 'd'), (self.F_dvpad, 'dv'))
                            if flag == 't')
        parts = [f'{self.tag}_v{self.F_vlayout[0]}']
        if pad_flags:
            parts.append('p' + pad_flags)
        if self.F_bias == 't':
            parts.append('bias')
        if self.F_mask.startswith('s_'):
            # simplified masks: only the enabled variant contributes a token
            if self.F_mask == 's_mask':
                parts.append('mask')
        elif self.F_mask != 'no':
            # generic masks: first letter of the mask tag
            parts.append('m' + self.F_mask[0])
        if self.F_lse == 't':
            parts.append('lse')
        if self.F_squant == 't':
            parts.append('squant')
        return '_'.join(parts)
|
|
|
|
class FmhaFwdApiPool:
    """Collects every generated kernel's traits and renders fmha_fwd_api.cpp."""

    def __init__(self, mask_impl):
        # pool: dtype -> hdim -> list of FmhaFwdApiTrait, in registration order
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        """Record one trait combination (copied, so later mutation is safe)."""
        # TODO: do we need to check duplication?
        # setdefault replaces the previous explicit `not in ....keys()` checks
        self.pool.setdefault(trait.dtype, dict()) \
                 .setdefault(trait.hdim, list()) \
                 .append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatch translation unit as a string."""
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                # first branch is `if`, the rest chain as `else if`
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                    F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                    F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
                                    F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                    F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                                    F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
|
|
|
@dataclass
class FmhaFwdTileSize:
    """Tile/warp geometry for one head-dim bucket."""
    F_bm0 : int       # tile size along q seqlen (block size)
    F_bn0 : int       # tile size along qk seqlen
    F_bk0 : int       # tile size along qk gemm unroll
    F_bn1 : int       # tile size along v head_dim
    F_bk1 : int       # tile size along kv gemm unroll
    F_bk0blen : int   # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm : int        # number of warps along q seqlen (block warps)
    F_rn : int        # number of warps along k seqlen(not used)
    F_rk : int        # number of warps along gemm-k(not used)
    F_wm : int        # warp size along m (warp size)
    F_wn : int        # warp size along n
    F_wk : int        # warp size along k
    F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        """Encode the geometry as the b..._r..._w...[_o...] fragment of a kernel name."""
        block = 'x'.join(str(v) for v in (self.F_bm0, self.F_bn0, self.F_bk0,
                                          self.F_bn1, self.F_bk1, self.F_bk0blen))
        warps = f'{self.F_rm}x{self.F_rn}x{self.F_rk}'
        wsize = f'{self.F_wm}x{self.F_wn}x{self.F_wk}'
        # occupancy is only encoded when explicitly overridden
        suffix = '' if self.F_occupancy == -1 else f'_o{self.F_occupancy}'
        return f'b{block}_r{warps}_w{wsize}{suffix}'
|
|
|
|
@dataclass
class FmhaFwdKernel:
    """One kernel instantiation: ties a tile geometry and pipeline to dtype/hdim/mode."""
    direction : str
    F_idx : int    # this is not a tunable, but a counter to differentiate symbol
    F_hdim : int   # hdim
    F_dtype : str  # data type
    F_mode : str   # value from MODE_MAP
    F_tile : FmhaFwdTileSize
    F_pipeline : FmhaFwdPipeline
    mask_impl : str

    @property
    def template(self) -> str:
        """Render the complete generated .cpp translation unit for this kernel."""
        # (an unused `kernel_body = str()` local was removed here)
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_tile.F_bm0,
                F_bn0           = self.F_tile.F_bn0,
                F_bk0           = self.F_tile.F_bk0,
                F_bn1           = self.F_tile.F_bn1,
                F_bk1           = self.F_tile.F_bk1,
                F_bk0blen       = self.F_tile.F_bk0blen,
                F_rm            = self.F_tile.F_rm,
                F_rn            = self.F_tile.F_rn,
                F_rk            = self.F_tile.F_rk,
                F_wm            = self.F_tile.F_wm,
                F_wn            = self.F_tile.F_wn,
                F_wk            = self.F_tile.F_wk,
                F_vlayout       = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad         = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad          = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_bias          = BOOL_MAP[self.F_pipeline.F_bias],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy     = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                F_pipeline      = PIPELINE_MAP[self.F_pipeline.tag])

    @property
    def name(self) -> str:
        """Kernel symbol / file stem built from direction, hdim, dtype, mode, tile and pipeline."""
        # TODO: we don't encode idx here
        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        """Name of the generated .cpp file for this kernel."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel into the trait record consumed by FmhaFwdApiPool."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0blen=self.F_tile.F_bk0blen,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            bias=self.F_pipeline.F_bias,
            lse=self.F_pipeline.F_lse,
            squant=self.F_pipeline.F_squant,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad)
|
|
|
|
# TODO: design a more practical way to do it
|
|
# this is current supported tile size per hdim
|
|
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
    """Return the hdim -> FmhaFwdTileSize table for *direction*/*dtype*, or None if unsupported."""
    if direction != 'fwd':
        return None
    if dtype in ('fp16', 'bf16'):
        return {
            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
        }
    if dtype in ('fp8', 'bf8'):
        return {
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1)
        }
    return None
|
|
|
|
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate every kernel instance to generate plus the api pool dispatching them.

    :param kernel_filter: fnmatch pattern applied to kernel names, or None for all
    :param receipt: codegen receipt; 1 adds extra `qr` pipelines for arbitrary hdim
    :param mask_impl: "generic" or "simplified" mask implementation
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
                if hdim == 256:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))

                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                else:
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                    if receipt == 1:
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
            # no need lse kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
        d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
        if d is None:  # was `== None`; identity check is the correct idiom
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                k = FmhaFwdKernel(direction=direction,
                                  F_idx=0,
                                  F_hdim=hdim,
                                  F_dtype=dtype,
                                  F_mode=mode,
                                  F_tile=tile,
                                  F_pipeline=pipeline,
                                  mask_impl=mask_impl)
                if kernel_filter is not None:  # was `!= None`
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
|
|
|
|
def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write one kernel instance's translation unit into *autogen_dir*."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)
|
|
|
|
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the rendered dispatch source (fmha_fwd_api.cpp) into *autogen_dir*."""
    destination = autogen_dir / FMHA_FWD_API_FILENAME
    destination.write_text(api_pool.api)
|
|
|
|
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate every kernel instance file plus the dispatch API file.

    When *output_dir* is None the files land next to this script; otherwise
    they go under GEN_DIR inside the given directory.
    """
    target = Path(__file__).parent if output_dir is None else Path(output_dir) / GEN_DIR
    target.mkdir(parents=True, exist_ok=True)

    api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
    for k in kernels:
        write_single_kernel(k, target)
    write_api(api_pool, target)
|
|
|
|
# list all the files that will be generated
|
|
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Append the path of every file generation would produce to *output_file*."""
    assert output_file is not None
    file_path = Path(output_file)
    base = file_path.parent / GEN_DIR
    _, kernels = get_blobs(kernel_filter, receipt, mask_impl)
    # kernel files first, the api file last — same order the generator writes them
    names = [kernel.filename for kernel in kernels] + [FMHA_FWD_API_FILENAME]
    with file_path.open('a') as f:
        f.writelines(str(base / n) + "\n" for n in names)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen api for CK fmha kernel",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        # BUGFIX: without type=int argparse yields a str for CLI input, so the
        # `receipt == 1` comparison in get_blobs() could never match `-r 1`
        type=int,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
             "  1: generate more instance to cover all hdim"
    )

    args = parser.parse_args()
    # -l lists the would-be outputs; otherwise actually generate them
    if args.list_blobs is not None:
        list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)
|