diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 7c61c8042c..972eb1208c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -320,7 +320,22 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, use_bias, lse}; - auto fmha_args = [&]() { + + auto pcompute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scale{10.f}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f}); + else + return ck_tile::identity{}; + }(); + + auto fmha_args = [&]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -357,46 +372,90 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - return fmha_fwd_args{q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask.left, - mask.right, - static_cast(mask.type), - ck_tile::identity{}, - ck_tile::identity{}}; + if constexpr(std::is_same_v) + return fmha_fwd_args{ + q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + pcompute_element_func, + oacc_element_func}; + else + return fmha_fwd_args{ + q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + pcompute_element_func, + oacc_element_func}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); @@ -481,7 +540,7 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, ck_tile::identity{}, ck_tile::identity{}, - [&](SaccDataType x) { return scale * x; }); + [&](SaccDataType x) { return pcompute_element_func(scale * x); }); if(use_bias) { @@ -549,7 +608,12 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref); + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index b51e33990f..c308fd83cd 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -163,9 +163,9 @@ float fmha_fwd_(const ck_tile::stream_conf FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" -using fmha_fwd_args_ = fmha_fwd_args<{F_element_func}>; +using fmha_fwd_args_{F_idx} = fmha_fwd_args<{F_element_func}>; template<> -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args_ a, const ck_tile::stream_config& s){{ +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args_{F_idx} a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; @@ -322,22 +322,28 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() + self.pool_f8 = dict() self.mask_impl = mask_impl def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + if trait.dtype == 'fp8': + if trait.hdim not in self.pool_f8.keys(): + self.pool_f8[trait.hdim] = list() + self.pool_f8[trait.hdim].append(copy.copy(trait)) + else: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) @property def api(self) -> str: + fmha_fwd_api=str() per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): - element_func='no' per_hdim_case=str() for j, hdim in enumerate(self.pool[dtype].keys()): traits=self.pool[dtype][hdim] @@ -345,18 +351,41 @@ class FmhaFwdApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP[element_func], - F_dispatch = per_dtypes) + + fmha_fwd_api = fmha_fwd_api + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP['no'], F_dispatch = per_dtypes, F_idx=0) + + # fp8 quantization + per_hdim_case=str() + for j, hdim in enumerate(self.pool_f8.keys()): + traits=self.pool_f8[hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP['fp8']) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + fp8_impl = FMHA_FWD_API_PER_DTYPE.format(F_if='if', F_dtype='fp8', F_hdim_case=per_hdim_case) + fmha_fwd_api = fmha_fwd_api + FMHA_FWD_API.format(F_element_func = ELEMENT_FUNC_MAP['f8_static_quant'], + F_dispatch = fp8_impl, + F_idx=1) + + return FMHA_FWD_KERNEL_HEADER + fmha_fwd_api @dataclass class FmhaFwdTileSize: @@ -393,36 +422,37 @@ class FmhaFwdKernel: @property def template(self) -> str: + kernel_body = str() return FMHA_FWD_KERNEL_HEADER + \ FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BOOL_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_occupancy = self.F_tile.F_occupancy, + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_element_func = ELEMENT_FUNC_MAP[self.F_element_func]) + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_element_func = ELEMENT_FUNC_MAP[self.F_element_func]) @property def name(self) -> str: @@ -521,7 +551,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - element_func='no' + element_func = 'no' if dtype != 'fp8' else 'f8_static_quant' k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim,