From 273c031ecaf0b8b989972dec846d869f4bd76512 Mon Sep 17 00:00:00 2001 From: ltqin Date: Mon, 7 Jul 2025 16:16:54 +0800 Subject: [PATCH] ck tile pagedkv prefill (#2405) * add prefetching physical block id for pagedkv * start add pagedkv prefill * rename pipeline * add kernel for pagedkv * add an init version pagedkv prefill * fix redefine issue * add struct BlockFmhaFwdPagedKVPipelineProblem and fmha_fwd_pagedkv_args * generate dispatch code * add body generating code * comipling pass * remove dropout from pagedkv * set lse to false in generating code * start changing qr kernel to pagedkv * init version of kernerl with pagedkv * change names of file that are generated * chang host validation for pagedkv prefill * using iglp to change blockgemm * add kernel files to op head file * show parameters * rewrite print parameter fun * add fwd * remove default parameter of GridSize * format * fix nhead issue and add seqlen_k_ptr to batch mode * format code * remove no-longer used code * format * fix some comments --------- Co-authored-by: ltqin Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 9f4c5d7372fcec5d4efc0957cb3bc9da7d6e70bc] --- example/ck_tile/01_fmha/CMakeLists.txt | 9 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 1 + .../codegen/ops/fmha_pagedkv_prefill.py | 585 +++++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 43 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 260 ++++ include/ck_tile/ops/fmha.hpp | 3 + .../ops/fmha/block/page_block_navigator.hpp | 71 + .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 1374 +++++++++++++++++ ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 751 +++++++++ ...gedkv_pipeline_qr_ks_vs_default_policy.hpp | 91 ++ ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 16 +- .../pipeline/block_fmha_pipeline_problem.hpp | 52 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 28 + include/ck_tile/ops/gemm.hpp | 1 + .../block/block_gemm_areg_bsmem_creg_v2r1.hpp | 247 +++ 15 files changed, 3520 insertions(+), 12 deletions(-) create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 4fc8b0b4c9..e73faf6325 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,5 +1,5 @@ # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(FMHA_FWD_ENABLE_APIS STREQUAL "all") @@ -102,6 +102,13 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example +if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +endif() + # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 5b9d5742b4..9e15a822ef 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -122,6 +122,7 @@ PIPELINE_ENUM_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py new file mode 100644 index 0000000000..650ebaf80e --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_PAGEDKV_PIPELINE_MAP = { + "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, //lse + {F_pagedkv}, //pagedkv + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdPagedKVKernel; + +using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + +#include + +template<> +float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return fmha_fwd_pagedkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + pagedkv : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_pagedkv : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + pagedkv=self.F_pipeline.F_pagedkv, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + elif dtype in ['fp8', 'bf8']: + # TODO + None + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + # if pipeline.F_pagedkv == 'f': + # continue + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 972653c218..e9403f4698 100755 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -323,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ + CK_TILE_FMHA_FWD_PAGEDKV_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" @@ -339,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) if(use_cache_batch_idx) { std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" @@ -547,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cerr << "num_splits greater than 128 is not supported" << std::endl; return false; } -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < p_drop && (1 < num_splits || use_kvcache)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" @@ -802,7 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser) << (is_rotary_interleaved ? "inter" : "half") << ")"; } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; @@ -843,6 +844,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { traits.has_dropout = (p_drop > 0.0f); } + else if constexpr(std::is_same_v>) + { + traits.use_pagedkv = use_kvcache; + } } }; @@ -1051,6 +1057,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; } + else if constexpr(std::is_same_v>) + { + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + } } }; @@ -1072,7 +1089,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_kvcache) + if(1 < num_splits && use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); @@ -1082,6 +1099,18 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } +#endif +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + } #endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1237,7 +1266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(i_perm) { k_host_ref.ForEach([&](auto& self, auto i) { @@ -1288,7 +1317,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(is_v_rowmajor) { if(i_perm) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 15b028fa9f..81dda692ea 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -178,6 +178,86 @@ struct fmha_fwd_args drop_seed_offset; }; +struct fmha_fwd_pagedkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; +}; + struct fmha_fwd_splitkv_args { const void* q_ptr; @@ -501,6 +581,114 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + // FmhaKernel::PrintParameters(kargs, args.batch); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + template auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) { @@ -870,6 +1058,57 @@ struct fmha_fwd_traits_ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +struct fmha_fwd_pagedkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +template +float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args); + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t /*id*/) const + { + + ck_tile::move_tile_window(tile_window, step); + return 0; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t /*block_index*/, + TileWindow /*tile_window*/, + const typename remove_cvref_t::BottomTensorIndex& /*step*/) const + { + return -1; + } + CK_TILE_HOST_DEVICE static constexpr WindowOrigin to_local_window_origin(const WindowOrigin& global_window_origin) { @@ -153,6 +174,56 @@ struct PageBlockNavigator return new_block_index; } + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t id) const + { + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + if(id >= 0) + tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride + + fixed_offset); + else + tile_window.set_bottom_tensor_view_data_ptr(nullptr); + + return new_block_index; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + auto local_tile_window = tile_window; // not affect origin window + ck_tile::move_tile_window(local_tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, local_tile_window.get_window_origin()); + const index_t new_block_index = get_block_index(global_window_origin); + + if(new_block_index < num_blocks) + { + return physical_block_indices[new_block_index]; + } + else + { + return -1; + } + } + CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const { return block_index == num_blocks - 1; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp new file mode 100644 index 0000000000..e56d518634 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -0,0 +1,1374 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +// TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline. +// Refactoring to extract shared logic is recommended as future work. +template +struct FmhaFwdPagedKVKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdSkipMinSeqlenQKargs + { + ck_tile::index_t min_seqlen_q = 0; + }; + + struct CommonPageBlockTableKargs + { + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + ck_tile::index_t page_block_size; + }; + + struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs + { + bool is_gappy = false; + }; + + struct CacheBatchIdxKargs + { + const int32_t* cache_batch_idx; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t, + std::conditional_t> + { + const int32_t* seqlen_k_ptr; + + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for + // single kcache page-block + ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for + // single vcache page-block + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + + ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size + // for single kcache page-block + ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size + // for single vcache page-block + }; + + using Kargs = std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + const void* seqlen_k_ptr, // only used for (paged-) kvcache + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + const void* cache_batch_idx, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for pagedkv + {}, // placeholder for logits_soft_cap + reinterpret_cast(seqlen_k_ptr), + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kIsPagedKV) + { + kargs.block_table_ptr = reinterpret_cast(block_table_ptr); + kargs.batch_stride_block_table = batch_stride_block_table; + kargs.page_block_size = page_block_size; + } + else + { + kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + const void* seqlen_k_ptr, // only used for (paged-) kvcache + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + const void* cache_batch_idx, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + block_table_ptr, + batch_stride_block_table, + page_block_size, + cache_batch_idx, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type); + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + bool is_gappy, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_k, // only used for paged-kvcache + ck_tile::index_t batch_stride_v, // only used for paged-kvcache + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for logits_soft_cap + {}, // placeholder for pagdkv + {}, // placeholder for min_seqlen_q + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr), + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + if constexpr(kIsPagedKV) + { + kargs.block_table_ptr = reinterpret_cast(block_table_ptr); + kargs.batch_stride_block_table = batch_stride_block_table; + kargs.page_block_size = page_block_size; + kargs.is_gappy = is_gappy; + } + if constexpr(kSkipMinSeqlenQ) + { + kargs.min_seqlen_q = min_seqlen_q; + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, + bool is_gappy, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_k, // only used for paged-kvcache + ck_tile::index_t batch_stride_v, // only used for paged-kvcache + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q) + { + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + block_table_ptr, + batch_stride_block_table, + page_block_size, + is_gappy, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_k, + batch_stride_v, + window_size_left, + window_size_right, + mask_type, + min_seqlen_q); + } + + CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches) + { + static bool dummy = [&]() { + std::cout << std::endl; + + std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr + << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr + << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v + << " num_head_q:" << kargs.num_head_q + << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s + << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k + << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o + << " nhead_stride_q: " << kargs.nhead_stride_q + << " nhead_stride_k: " << kargs.nhead_stride_k + << " nhead_stride_v:" << kargs.nhead_stride_v + << " nhead_stride_o: " << kargs.nhead_stride_o; + if constexpr(!kIsGroupMode) + { + std::cout << " batch_stride_q:" << kargs.batch_stride_q; + } + std::cout << " batch_stride_k:" << kargs.batch_stride_k + << " batch_stride_v:" << kargs.batch_stride_v; + + if constexpr(kIsGroupMode) + { + if constexpr(kSkipMinSeqlenQ) + { + std::cout << " min_seqlen_q: " << kargs.min_seqlen_q; + } + + std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr + << " seqstart_k_ptr: " << kargs.seqstart_k_ptr + << " seqlen_k_ptr:" << kargs.seqlen_k_ptr; + if(kargs.seqlen_k_ptr != nullptr) + { + std::cout << "{"; + for(int i_batch = 0; i_batch < num_batches; i_batch++) + std::cout << kargs.seqlen_k_ptr[i_batch] << ","; + std::cout << "}"; + } + } + if constexpr(kHasMask) + { + std::cout << " window_size_left: " << kargs.window_size_left + << " window_size_right:" << kargs.window_size_right + << " mask_type: " << static_cast(kargs.mask_type); + } + + if constexpr(kIsPagedKV) + { + std::cout << " block_table_ptr: " << kargs.block_table_ptr + << " batch_stride_block_table:" << kargs.batch_stride_block_table + << " page_block_size: " << kargs.page_block_size; + + std::cout << "table value: ["; + for(int b = 0; b < num_batches; b++) + { + std::cout << "[ "; + for(int i = 0; i < kargs.batch_stride_block_table; i++) + { + std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i] + << ","; + } + std::cout << " ]"; + } + std::cout << " ]"; + } + std::cout << std::endl; + return true; + }(); + (void)dummy; + } + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_, + bool has_padded_seqlen_k) + { + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + index_t kv_l2p_offset = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + + if constexpr(kIsPagedKV) + { + if(kargs.is_gappy) + { + // seqstart_k_ptr has different meaning in this case + kv_l2p_offset = kargs.seqstart_k_ptr[i_batch]; + } + } + } + else + { + const index_t i_cache_batch = [&, i_batch_ = i_batch] { + if constexpr(kIsPagedKV) + { + return i_batch_; + } + else + { + return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_] + : i_batch_); + } + }(); + + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto make_k_dram = [&](const KDataType* data, index_t height) { + const auto k_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(height, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }; + const auto k_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_k_dram(nullptr, kargs.page_block_size); + } + else + { + return make_k_dram(k_ptr, kargs.seqlen_k); + } + }(); + + const auto make_v_dram = [&](const VDataType* data, index_t length) { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(length)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(kargs.hdim_v, length), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }; + const auto v_dram = [&]() { + if constexpr(kIsPagedKV) + { + return make_v_dram(nullptr, kargs.page_block_size); + } + else + { + return make_v_dram(v_ptr, kargs.seqlen_k); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_page_block_navigator = + [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_k; + + return make_page_block_navigator( + kargs.k_ptr, + kargs.batch_stride_k, // kcache page-block stride/size + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + k_dram, + make_k_dram(nullptr, + (kv_l2p_offset + kargs.seqlen_k) - + (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(k_dram); + } + }(); + + auto v_page_block_navigator = + [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_v; + + return make_page_block_navigator( + kargs.v_ptr, + kargs.batch_stride_v, // vcache page-block stride/size + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size, + v_dram, + make_v_dram(nullptr, + (kv_l2p_offset + kargs.seqlen_k) - + (num_blocks - 1) * kargs.page_block_size)); + } + else + { + return make_page_block_navigator(v_dram); + } + }(); + + auto k_dram_window_lengths = + make_tuple(number{}, number{}); + auto v_dram_window_lengths = + make_tuple(number{}, number{}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window_lengths, + k_page_block_navigator, + identity{}, // k_element_func + v_dram_window_lengths, + v_page_block_navigator, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + kv_l2p_offset, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window_lengths, + k_page_block_navigator, + v_dram_window_lengths, + v_page_block_navigator, + bias_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + kv_l2p_offset, + smem_ptr); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp new file mode 100644 index 0000000000..6ad5844b69 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -0,0 +1,751 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// TODO: This class is a variant of the existing BlockFmhaFwdSplitKVPipelineQRKSVS pipeline. +// Refactoring to extract shared logic is recommended as future work. + +template +struct BlockFmhaFwdPagedKVPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_pagedkv"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const KElementFunction& k_element_func, + const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate + void* smem_ptr) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowLengths{}[number<0>{}] && + kK0 == KDramBlockWindowLengths{}[number<1>{}] && + kN1 == VDramBlockWindowLengths{}[number<0>{}] && + kK1 == VDramBlockWindowLengths{}[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + const auto num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // k_dram_block_window + const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; + const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = + [&, physical_seqlen_k_start_ = physical_seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0); + } + else + { + return physical_seqlen_k_start_; + } + }(); + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + // v_dram_window + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + v_dram_block_window_lengths, + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window, + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + // moving k_dram_window is an in-page-block operation, so there is + // no need to invoke k_page_block_navigator.move_tile_window() here. + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + auto physical_next_block_id_k = + __builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id( + i_page_block_k, k_dram_block_window, {kN0, 0})); + auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane( + v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + // position_encoding accept only logical coordinates, do conversion here + position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); +#endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); + // mask accept only logical coordinates, do conversion here + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}) - kv_l2p_offset, + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col - kv_l2p_offset, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&, + &i_page_block_v_ = i_page_block_v, + &v_dram_window_ = v_dram_window](auto i_k1) { + auto physical_next_block_id_v_ = + __builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id( + i_page_block_v_, v_dram_window_, {0, kK1})); + const auto v = load_tile(v_dram_window_); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + i_page_block_v_ = v_page_block_navigator.move_tile_window( + i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_); + }); + } + // move K tile windows + i_page_block_k = k_page_block_navigator.move_tile_window( + i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_lengths, + k_page_block_navigator, + identity{}, + v_dram_block_window_lengths, + v_page_block_navigator, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + kv_l2p_offset, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 0000000000..9c348495ff --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 32); + + // TODO: hard coded here. Otherwise, it may incorrect result + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } // TODO - bf8_t + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm0Warps) + { + if constexpr(128 >= Problem::BlockFmhaShape::kK0) + return BlockGemmARegBSmemCRegV2R1{}; + else + return BlockGemmARegBSmemCRegV2{}; + } + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 7f5f79d7a7..9de640b7cf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -320,6 +320,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + auto physical_next_block_id_k = + __builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id( + i_page_block_k, k_dram_block_window, {kN0, 0})); + auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane( + v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -600,8 +605,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - i_page_block_v = - v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -612,6 +617,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static_for<0, k1_loops - 1, 1>{}([&, &i_page_block_v_ = i_page_block_v, &v_dram_window_ = v_dram_window](auto i_k1) { + auto physical_next_block_id_v_ = + __builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id( + i_page_block_v_, v_dram_window_, {0, kK1})); const auto v = load_tile(v_dram_window_); // load next v block_sync_lds(); gemm_1(o_acc, @@ -634,12 +642,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS tile_elementwise_in(v_element_func, v)); // store next v } i_page_block_v_ = v_page_block_navigator.move_tile_window( - i_page_block_v_, v_dram_window_, {0, kK1}); + i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_); }); } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); + i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 21cc4950eb..20b30b7417 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -61,6 +61,58 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdPagedKVPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; + static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = Traits::kIsPagedKV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + template 1 or fwd training is running */ + bool kIsPagedKV_, + bool kDoFp8StaticQuant_, + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> +struct TileFmhaFwdPagedKVTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + template +struct BlockGemmARegBSmemCRegV2R1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensors; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); + }); + }); + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = b_warp_tensors(nIter)(kIter); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + static_for<0, KIterPerWarp, 1>{}([&](auto) { + static_for<0, NIterPerWarp, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile