1. codgen the f8 api and kernel

2. f8 host code
This commit is contained in:
rocking
2024-04-08 21:36:23 +00:00
parent 5860f3134a
commit 525b89e538
2 changed files with 182 additions and 88 deletions

View File

@@ -320,7 +320,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type,
use_bias,
lse};
auto fmha_args = [&]() {
auto pcompute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scale{10.f};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
else
return ck_tile::identity{};
}();
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
@@ -357,46 +372,90 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args<FmhaDefaultElementFunctions>{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
ck_tile::identity{},
ck_tile::identity{}};
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return fmha_fwd_args<FmhaF8StaticQuantizationElementFunctions>{
q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
pcompute_element_func,
oacc_element_func};
else
return fmha_fwd_args<FmhaDefaultElementFunctions>{
q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
pcompute_element_func,
oacc_element_func};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
@@ -481,7 +540,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
[&](SaccDataType x) { return scale * x; });
[&](SaccDataType x) { return pcompute_element_func(scale * x); });
if(use_bias)
{
@@ -549,7 +608,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host_ref, o_host_ref);
p_host_ref,
v_host_ref,
o_host_ref,
ck_tile::identity{},
ck_tile::identity{},
oacc_element_func);
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off

View File

@@ -163,9 +163,9 @@ float fmha_fwd_<trait_{F_idx}, fmha_fwd_args_{F_idx}>(const ck_tile::stream_conf
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
using fmha_fwd_args_ = fmha_fwd_args<{F_element_func}>;
using fmha_fwd_args_{F_idx} = fmha_fwd_args<{F_element_func}>;
template<>
float fmha_fwd<fmha_fwd_args_>(fmha_fwd_traits t, fmha_fwd_args_ a, const ck_tile::stream_config& s){{
float fmha_fwd<fmha_fwd_args_{F_idx}>(fmha_fwd_traits t, fmha_fwd_args_{F_idx} a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
@@ -322,22 +322,28 @@ class FmhaFwdPipeline:
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool_f8 = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
if trait.dtype == 'fp8':
if trait.hdim not in self.pool_f8.keys():
self.pool_f8[trait.hdim] = list()
self.pool_f8[trait.hdim].append(copy.copy(trait))
else:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
fmha_fwd_api=str()
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
element_func='no'
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
@@ -345,18 +351,41 @@ class FmhaFwdApiPool:
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP[element_func],
F_dispatch = per_dtypes)
fmha_fwd_api = fmha_fwd_api + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP['no'], F_dispatch = per_dtypes, F_idx=0)
# fp8 quantization
per_hdim_case=str()
for j, hdim in enumerate(self.pool_f8.keys()):
traits=self.pool_f8[hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP['fp8'])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
fp8_impl = FMHA_FWD_API_PER_DTYPE.format(F_if='if', F_dtype='fp8', F_hdim_case=per_hdim_case)
fmha_fwd_api = fmha_fwd_api + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP['f8_static_quant'],
F_dispatch = fp8_impl,
F_idx=1)
return FMHA_FWD_KERNEL_HEADER + fmha_fwd_api
@dataclass
class FmhaFwdTileSize:
@@ -393,36 +422,37 @@ class FmhaFwdKernel:
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy,
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_element_func = ELEMENT_FUNC_MAP[self.F_element_func])
@property
def name(self) -> str:
@@ -521,7 +551,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
element_func='no'
element_func = 'no' if dtype != 'fp8' else 'f8_static_quant'
k = FmhaFwdKernel(direction=direction,
F_idx=0,
F_hdim=hdim,