From 9ac2654b5577e9875eebc533e76a4a2b6492e3da Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 2 Jun 2024 23:32:55 +0000 Subject: [PATCH] Add SplitKV combine kernel codegen logics --- example/ck_tile/01_fmha/fmha_fwd.hpp | 31 +- example/ck_tile/01_fmha/generate.py | 252 +++++++++- include/ck_tile/ops/fmha.hpp | 3 +- .../fmha_fwd_splitkv_combine_kernel.hpp | 439 ++++++++++++++++++ ...a_fwd_splitkv_combine_tile_partitioner.hpp | 35 ++ ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 433 +++++++++++++++++ ...plitkv_combine_pipeline_default_policy.hpp | 18 + ...ha_fwd_splitkv_pipeline_default_policy.hpp | 0 .../pipeline/block_fmha_pipeline_problem.hpp | 1 + ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 100 ++++ .../ops/fmha/pipeline/tile_fmha_traits.hpp | 14 +- 11 files changed, 1284 insertions(+), 42 deletions(-) create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp delete mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index e9ccaac54f..71a80c4d61 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -334,7 +334,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } -#if 0 template auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) { @@ -343,31 +342,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) // create group mode kernel argumentszs if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode) { - return FmhaFwdSplitKVCombineKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, + return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, args.lse_ptr, args.o_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, args.hdim_v, - args.nhead_q / args.nhead_k, args.num_splits, - args.scale_s, - args.scale_p, args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, args.nhead_stride_lse, args.nhead_stride_o); } @@ -395,7 +381,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); return ck_tile::make_tuple(kargs, grids); } -#endif // this is used to pattern-match internl kernel implementation, not to instantiate kernel template std::string fmha_fwd_splitkv_get_name_(); +template +float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + // This is the public API, will be generated by script struct fmha_fwd_traits { diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index f27dbf1d47..2d103558ea 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -249,6 +249,86 @@ std::string fmha_fwd_splitkv_get_name_() }} """ +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}, + 16>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdSplitKVCombineKernel, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = fmha_kernel_{F_idx}; + return k_::GetName(); +}} +""" + FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ @@ -288,19 +368,18 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" FMHA_FWD_SPLITKV_API=""" #include -template +template float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_fwd_splitkv_get_name_() - // << ", " << fmha_fwd_splitkv_combine_get_name_() + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }} - // , [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} ); }} @@ -313,8 +392,8 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_splitkv_(s, a); + using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_splitkv_(s, a); }} """ @@ -713,6 +792,85 @@ class FmhaFwdSplitKVKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad) +@dataclass +class FmhaFwdSplitKVCombineKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + # TODO: design a more practical way to do it # this is current supported tile size per hdim def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: @@ -812,6 +970,74 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, is_splitkv return (api_pool, gen) +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + + for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): + d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(direction=direction, + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + gen.append(k) + + return gen + BWD_DQDKDV_PIPELINE_MAP = { "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", @@ -1386,7 +1612,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: return gen -def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel], autogen_dir: Path) -> None: +def write_single_fwd_kernel(kernel: Union[FmhaFwdKernel, FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: (autogen_dir / kernel.filename).write_text(kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: @@ -1416,6 +1642,9 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio write_fwd_api(api_pool, output_dir) # write split-kv blobs + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -1441,6 +1670,9 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") # get split-kv blobs + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, True) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 627eb1abf9..057d2b11ff 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -10,6 +10,8 @@ #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" @@ -26,7 +28,6 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp new file mode 100644 index 0000000000..7eddc3cede --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* lse_acc_ptr; + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t batch; + ck_tile::index_t nhead; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + ck_tile::index_t num_splits; + + ck_tile::index_t row_stride_o; + ck_tile::index_t nhead_stride_o; + }; + + struct CommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + }; + + struct Fp8StaticQuantKargs + { + float scale_o; + }; + + struct BatchModeLSEKargs : CommonLSEKargs + { + ck_tile::index_t batch_stride_lse = 0; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_o; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + nhead, + max_seqlen_q, + seqlen_q, + hdim_v, + num_splits, + row_stride_o, + nhead_stride_o}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + batch_stride_o}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + nhead, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + hdim_v, + num_splits, + row_stride_o, + nhead_stride_o}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + reinterpret_cast(seqstart_q_ptr)}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + __host__ static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + + long_index_t batch_offset_lse_acc = 0; + long_index_t batch_offset_o_acc = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_lse_acc = query_start; + batch_offset_o_acc = query_start * kargs.hdim_v; + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + batch_offset_o = query_start * kargs.row_stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_lse_acc = + static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); + batch_offset_o_acc = static_cast(i_batch) * + (kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v); + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const LSEDataType* lse_acc_ptr = reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead) * (kargs.max_seqlen_q) + + batch_offset_lse_acc; + const OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) + + batch_offset_o_acc; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // LSEacc/Oacc DRAM and DRAM windows + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.num_splits, kargs.seqlen_q), + make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1), + number<8>{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple( + kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1), + number{}, + number<1>{}); + + auto o_acc_dram_view = pad_tensor_view( + o_acc_dram_naive, + make_tuple(number<1>{}, number{}, number{}), + sequence{}); + + const index_t new_seqlen_q = + integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0; + const index_t new_hdim_v = + integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1; + + return transform_tensor_view( + o_acc_dram_view, + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)), + make_pass_through_transform(new_hdim_v)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }(); + + auto lse_acc_dram_window = make_tile_window( + lse_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {0, i_m0}); + + auto o_acc_dram_window = make_tile_window( + o_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + // LSE DRAM window + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + identity{}, // lse_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + smem_ptr, + kargs.num_splits, + kargs.max_seqlen_q); + } + else + { + return FmhaPipeline{}(lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + smem_ptr, + kargs.num_splits, + kargs.max_seqlen_q); + } + }(); + + // O DRAM and DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.row_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp new file mode 100644 index 0000000000..c2b79db9aa --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN0; + // constexpr static ck_tile::index_t kBlockM = kN1 % 128 == 0 ? 4 : (kN1 % 64 == 0 ? 8 : 16); + + __host__ static constexpr auto + GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q) + { + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), nhead, batch_size); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/) + { + const index_t i_tile_m = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp new file mode 100644 index 0000000000..0e11cbb98d --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaFwdSplitKVCombinePipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using QDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr index_t kMaxSplits = Problem::kMaxSplits; + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + /// TODO: add padding to avoid bank conflict + return (kM0 * kMaxSplits * sizeof(LSEDataType)); + } + +#define MARKER(msg) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " msg ::); \ + __builtin_amdgcn_sched_barrier(0) + + template + CK_TILE_HOST_DEVICE auto + operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, + const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, + const LSEElementFunction& lse_element_func, + const OaccElementFunction& o_acc_element_func, + void* smem_ptr, + index_t num_splits, + index_t max_seqlen_q) const + { + LSEDataType* lse_acc_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + + auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution(); + auto lse_acc_dram_window = + make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(), + lse_acc_dram_block_window_tmp.get_window_lengths(), + lse_acc_dram_block_window_tmp.get_window_origin(), + lse_acc_dist); + + auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0] + +#if defined(ENABLE_DEBUG_STMTS) +#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) +#else +#define DEBUG_STMTS if(false) +#endif + // copy lse_acc to LDS + { + using DataType = LSEDataType; + using StaticTileDistribution = decltype(lse_acc_dist); + + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + StaticTileDistribution{}, distributed_indices); + + const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); + + lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices); + }); + }); + } + block_sync_lds(); + + auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution(); + auto lse_accum = make_static_distributed_tensor(lse_accum_dist); + + // copy LDS to lse_accum + { + using DataType = LSEDataType; + using StaticTileDistribution = decltype(lse_accum_dist); + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + StaticTileDistribution{}, distributed_indices); + + const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); + + if(col < num_splits) + { + lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits]; + } + else + { + lse_accum(distributed_indices) = -numeric::infinity(); + } + + DEBUG_STMTS + { + printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n", + row, + col, + lse_accum(distributed_indices)); + } + }); + }); + } + + // calculate row_max of lse_accum + const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto lse_max = block_tile_reduce( + lse_accum, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(lse_max, f_max, bool_constant{}); + +#if defined(PRINT_LSE_MAX) + DEBUG_STMTS + { + constexpr auto out_spans = + static_distributed_tensor:: + get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + constexpr auto distributed_indices = make_tuple(idx0); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_max.get_tile_distribution(), distributed_indices); + + const auto row = x_indices.at(number<0>{}); + + printf( + "[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices)); + }); + } +#endif + + static const auto get_validated_m = [](LSEDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + auto p_compute = make_static_distributed_tensor( + lse_accum.get_tile_distribution()); // Pcompute{j} + clear_tile(p_compute); + { + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + p_compute.get_tile_distribution(), i_j_idx); + + const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); + +#if 0 + // from dist tensor + p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); +#else + if (col < num_splits) { + // from shared memory + p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] - + get_validated_m(lse_max(i_idx))); + } +#endif +#if 0 + DEBUG_STMTS + { + printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n", + row, + col, + p_compute(i_j_idx)); + } +#endif + }); + }); + } + __syncthreads(); + + auto lse_sum = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, type_convert(0)); + block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); + +#if defined(PRINT_LSE_SUM) + DEBUG_STMTS + { + constexpr auto out_spans = + static_distributed_tensor:: + get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + constexpr auto distributed_indices = make_tuple(idx0); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_sum.get_tile_distribution(), distributed_indices); + + const auto row = x_indices.at(number<0>{}); + + printf( + "[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices)); + }); + } +#endif + + decltype(lse_max) lse_logsum; + { + constexpr auto out_spans = static_distributed_tensor< + LSEDataType, + decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + constexpr auto distributed_indices = make_tuple(idx0); + + if(lse_sum(distributed_indices) == 0.f || + lse_sum(distributed_indices) != lse_sum(distributed_indices)) + { + lse_logsum(distributed_indices) = numeric::infinity(); + } + else + { + lse_logsum(distributed_indices) = + ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices); + } + +#if defined(PRINT_LSE_LOGSUM) + DEBUG_STMTS + { + const auto x_indices = get_x_indices_from_distributed_indices( + lse_logsum.get_tile_distribution(), distributed_indices); + + const auto row = x_indices.at(number<0>{}); + printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n", + row, + lse_logsum(distributed_indices)); + } +#endif + }); + } + +#if defined(PRINT_LSE_ACCUM) + DEBUG_STMTS + { + for(index_t row = 0; row < kM0; ++row) + { + printf("[POYENC][DEVICE] lse_accum[%d] = ", row); + for(index_t col = 0; col < num_splits; ++col) + { + printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]); + } + printf("\n"); + } + } +#endif + + // write lse scales into LDS + { + constexpr auto out_spans = + static_distributed_tensor:: + get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + constexpr auto distributed_indices = make_tuple(idx0); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_sum.get_tile_distribution(), distributed_indices); + + const auto row = x_indices.at(number<0>{}); + + for(index_t col = 0; col < num_splits; ++col) + { + lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp( + lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices)); + } + }); + } + block_sync_lds(); + +#if defined(PRINT_LSE_SCALE) + DEBUG_STMTS + { + for(index_t row = 0; row < 32; ++row) + { + printf("[POYENC][DEVICE] lse_scale[%2d] = ", row); + for(index_t col = 0; col < num_splits; ++col) + { + printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]); + } + printf("\n"); + } + } +#endif + + if constexpr(kStoreLSE) + { + static_assert(kBlockSize == 256); + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); + } + + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + auto o_acc_dram_window = + make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), + o_acc_dram_block_window_tmp.get_window_lengths(), + o_acc_dram_block_window_tmp.get_window_origin(), + o_acc_dist); + auto o_acc = make_static_distributed_tensor(o_acc_dist); // Pcompute{j} + clear_tile(o_acc); + + // [POYENC] added + for(index_t i_split = 0; i_split < num_splits; ++i_split) + { + auto o_tile = load_tile(o_acc_dram_window); + { + using DataType = OaccDataType; + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = + get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices); + + const auto row = x_indices.at(number<0>{}); + + LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; + o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); + }); + }); + } + + move_tile_window(o_acc_dram_window, {max_seqlen_q, 0}); + } + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window, + const OaccDramBlockWindow& o_acc_dram_block_window, + LSEDramBlockWindow& lse_dram_block_window, + void* smem_ptr, + index_t num_splits, + index_t max_seqlen_q) const + { + return operator()(lse_acc_dram_block_window, + o_acc_dram_block_window, + lse_dram_block_window, + identity{}, + identity{}, + smem_ptr, + num_splits, + max_seqlen_q); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp new file mode 100644 index 0000000000..98f2e4dd1b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +using BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_default_policy.hpp deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 1b72b60054..b0be881b29 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -52,6 +52,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr index_t kMaxSplits = Traits::kMaxSplits; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 12af81bb98..d4adf9792f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -985,6 +985,106 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; return BlockGemmARegBSmemCRegV2{}; } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kBlockSize = 256; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + + constexpr index_t NumElements = (kMPerBlock * kNPerBlock); + + if constexpr(NumElements < kBlockSize) + { + static_assert(false); + } + else + { + static_assert(sizeof(LSEDataType) == 4); + + constexpr index_t NPerThread = 16 / sizeof(LSEDataType); // 4 + constexpr index_t NThreads = kNPerBlock / NPerThread; // 32 + static_assert(NThreads <= kBlockSize); + + constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; // 2 + constexpr index_t TotalWarps = kBlockSize / get_warp_size(); // 4 + constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); // 2 + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution() + { + constexpr index_t kBlockSize = 256; + + constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t NumElements = (kMPerBlock * kNPerBlock); + + static_assert(kBlockSize < NumElements); + if constexpr(NumElements < kBlockSize) + { + static_assert(false); + } + else + { + constexpr index_t NThreads = get_warp_size(); // 64 + constexpr index_t NPerThread = kNPerBlock / NThreads; // 1 + + constexpr index_t MThreads = kBlockSize / NThreads; // 4 + constexpr index_t MPerThread = kMPerBlock / MThreads; // 32 + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + { + using OaccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + + constexpr index_t N1 = 16 / sizeof(OaccDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = get_warp_size() / N0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 04825ed514..c523b6dc8e 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -17,7 +17,8 @@ template + index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */, + index_t kMaxSplits_ = 1> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -30,16 +31,7 @@ struct TileFmhaTraits static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - -template -struct TileFmhaFwdSplitKVCombineTraits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr index_t kMaxSplits = kMaxSplits_; }; template