From 1fc6f29f35967cab4ad2157f7b440d5eb400c505 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Wed, 30 Oct 2024 14:03:16 +0800 Subject: [PATCH] [CK_TILE] Add fmha fwd headdim96 support (#1608) * Add ceil_to_qualified_tile_length() * Rename kK0BlockLength to kQKHeaddim * Add kSubQKHeaddim concept to support headdim96 * Fix in math.hpp to avoid using __half interfaces * Add LdsBufferSequence instance for headdim96 * Update in fmha_fwd/fmha_fwd_splitkv codegen to support hd96 testing * Disable hd96 instance generation in codegen fmha_fwd and fmha_fwd_splitkv to save compiling time * Reformat one file * Fix text alignment in fmha_fwd_splitkv.py --------- Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 863222181477ff42e809d034428f9160490a63ba] --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 41 +++++++++++------- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 42 ++++++++++++------- include/ck_tile/core/numeric/math.hpp | 12 +++--- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 8 ++-- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 8 ++-- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 23 +++++----- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 23 +++++----- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 23 +++++----- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 22 +++++----- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 23 +++++----- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 15 ++++--- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 20 ++++++++- 12 files changed, 153 insertions(+), 107 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 805803fed4..e5ee1d22e7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -21,6 +21,14 @@ DTYPE_BITS = { "bf8" : 8 } +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + TILE_PARTITIONER_MAP = { "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", @@ -35,7 +43,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT FMHA_FWD_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -125,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -142,7 +150,7 @@ class FmhaFwdApiTrait: bk0 : int # tile size along qk gemm unroll bn1 : int # tile size along v head_dim bk1 : int # tile size along kv gemm unroll - bk0blen : int + bk0max : int vlayout : str mask : str bias : str # @@ -156,7 +164,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property @@ -188,8 +196,9 @@ class FmhaFwdApiTrait: if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' else: assert False @property @@ -199,8 +208,9 @@ class FmhaFwdApiTrait: if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' else: assert False @dataclass @@ -271,7 +281,7 @@ class FmhaFwdApiPool: F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) @@ -289,7 +299,7 @@ class FmhaFwdTileSize: F_bk0 : int # tile size along qk gemm unroll F_bn1 : int # tile size along v head_dim F_bk1 : int # tile size along kv gemm unroll - F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_rm0 : int # number of warps for gemm0 along q seqlen F_rn0 : int # number of warps for gemm0 along k seqlen F_rk0 : int # number of warps for gemm0 along head dim q (not used) @@ -302,7 +312,7 @@ class FmhaFwdTileSize: F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @@ -335,7 +345,7 @@ class FmhaFwdKernel: F_bk0 = self.F_tile.F_bk0, F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, + F_bk0max = self.F_tile.F_bk0max, F_rm0 = self.F_tile.F_rm0, F_rn0 = self.F_tile.F_rn0, F_rk0 = self.F_tile.F_rk0, @@ -382,7 +392,7 @@ class FmhaFwdKernel: bk0=self.F_tile.F_bk0, bn1=self.F_tile.F_bn1, bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, + bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, @@ -401,6 +411,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), } @@ -510,4 +521,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 46c26b22cd..b084e9d0fc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -29,6 +29,14 @@ DTYPE_BITS = { "bf8" : 8 } +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", @@ -41,7 +49,7 @@ using fmha_mask_{F_idx} = {F_mask}; namespace {{ template struct kernel_runner {{ -using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape = ck_tile::TileFmhaShape; @@ -241,7 +249,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); @@ -260,7 +268,7 @@ class FmhaFwdSplitKVApiTrait: bk0 : int # tile size along qk gemm unroll bn1 : int # tile size along v head_dim bk1 : int # tile size along kv gemm unroll - bk0blen : int + bk0max : int vlayout : str mask : str bias : str # @@ -270,11 +278,11 @@ class FmhaFwdSplitKVApiTrait: skpad : str dpad : str dvpad : str - pagedkv : str + pagedkv : str @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @@ -307,8 +315,9 @@ class FmhaFwdSplitKVApiTrait: if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' else: assert False @property @@ -318,8 +327,9 @@ class FmhaFwdSplitKVApiTrait: if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' else: assert False @dataclass @@ -414,7 +424,7 @@ class FmhaFwdSplitKVApiPool: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) @@ -458,7 +468,7 @@ class FmhaFwdSplitKVKernel: F_bk0 = self.F_tile.F_bk0, F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, + F_bk0max = self.F_tile.F_bk0max, F_rm0 = self.F_tile.F_rm0, F_rn0 = self.F_tile.F_rn0, F_rk0 = self.F_tile.F_rk0, @@ -504,7 +514,7 @@ class FmhaFwdSplitKVKernel: bk0=self.F_tile.F_bk0, bn1=self.F_tile.F_bn1, bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, + bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, @@ -559,6 +569,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), } @@ -576,6 +587,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), + ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), } @@ -604,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) @@ -743,4 +755,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 785691b66f..0faf1aa043 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -1126,7 +1126,7 @@ CK_TILE_DEVICE int8_t neg(int8_t x) template <> CK_TILE_DEVICE fp16_t neg(fp16_t x) { - return __hneg(x); + return -x; }; template @@ -1168,7 +1168,7 @@ CK_TILE_DEVICE double sin(double x) template <> CK_TILE_DEVICE fp16_t sin(fp16_t x) { - return ::hsin(x); + return __ocml_sin_f16(x); }; template @@ -1300,7 +1300,7 @@ CK_TILE_DEVICE double ceil(double x) template <> CK_TILE_DEVICE fp16_t ceil(fp16_t x) { - return ::hceil(x); + return __ocml_ceil_f16(x); }; template @@ -1342,7 +1342,7 @@ CK_TILE_DEVICE double floor(double x) template <> CK_TILE_DEVICE fp16_t floor(fp16_t x) { - return ::hfloor(x); + return __ocml_floor_f16(x); }; template @@ -1365,7 +1365,7 @@ CK_TILE_DEVICE T exp(T x) template <> CK_TILE_DEVICE fp16_t exp(fp16_t x) { - return hexp(x); + return __ocml_exp_f16(x); }; template <> @@ -1389,7 +1389,7 @@ CK_TILE_DEVICE T log(T x) template <> CK_TILE_DEVICE fp16_t log(fp16_t x) { - return hlog(x); + return __ocml_log_f16(x); }; template <> diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 8c1f6c8058..e0c145fde7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -82,10 +82,10 @@ struct FmhaFwdKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + @@ -657,7 +657,7 @@ struct FmhaFwdKernel { return pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } else @@ -724,7 +724,7 @@ struct FmhaFwdKernel [&]() { if constexpr(FmhaPipeline::kQLoadOnce) return make_tuple(number{}, - number{}); + number{}); else return make_tuple(number{}, number{}); }(), diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index ea30025b53..4ffebc3c9c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -78,10 +78,10 @@ struct FmhaFwdSplitKVKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + @@ -586,7 +586,7 @@ struct FmhaFwdSplitKVKernel { return pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); } else @@ -735,7 +735,7 @@ struct FmhaFwdSplitKVKernel [&]() { if constexpr(FmhaPipeline::kQLoadOnce) return make_tuple(number{}, - number{}); + number{}); else return make_tuple(number{}, number{}); }(), diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 6e7416ce8e..71c3bd1715 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -34,12 +34,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return Problem::kBlockPerCu; else { - if constexpr(kK0BlockLength <= 32) + if constexpr(kQKHeaddim <= 32) { return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kQKHeaddim <= 64) { return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kQKHeaddim <= 128) { if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kQKHeaddim <= 256) { return 1; } @@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS // prefetch K tile index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(2 <= k0_loops); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 6837ffdee7..a7e9287143 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS return Problem::kBlockPerCu; else { - if constexpr(kK0BlockLength <= 32) + if constexpr(kQKHeaddim <= 32) { return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kQKHeaddim <= 64) { return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kQKHeaddim <= 128) { if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kQKHeaddim <= 256) { return 1; } @@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS // prefetch K tile index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(2 <= k0_loops); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 05d3dae1cc..10bb01168f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) @@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync return 1; } - if constexpr(kK0BlockLength <= 32) + if constexpr(kQKHeaddim <= 32) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && FmhaMask::IsMasking) @@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync else return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kQKHeaddim <= 64) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 2; else return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kQKHeaddim <= 128) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kQKHeaddim <= 256) { return 1; } @@ -339,7 +340,7 @@ struct BlockFmhaPipelineQRKSVSAsync // auto q_tile = q; // tile_elementwise_in(q_element_func, q); index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(1 <= k0_loops); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index f4767de0e9..a1b1e0e158 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 return Problem::kBlockPerCu; else { - if constexpr(kK0BlockLength <= 32) + if constexpr(kQKHeaddim <= 32) { return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kQKHeaddim <= 64) { return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kQKHeaddim <= 128) { if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kQKHeaddim <= 256) { return 1; } @@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 // prefetch K tile index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(2 <= k0_loops); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index d08a8d4899..b98247df9c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -36,12 +36,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS return Problem::kBlockPerCu; else { - if constexpr(kK0BlockLength <= 32) + if constexpr(kQKHeaddim <= 32) { return 2; } - else if constexpr(kK0BlockLength <= 64) + else if constexpr(kQKHeaddim <= 64) { return 3; } - else if constexpr(kK0BlockLength <= 128) + else if constexpr(kQKHeaddim <= 128) { if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; } - else if constexpr(kK0BlockLength <= 256) + else if constexpr(kQKHeaddim <= 256) { return 1; } @@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS // prefetch K tile index_t i_total_loops = 0; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(2 <= k0_loops); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 807ad65485..fbb05e1641 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr index_t MWarp = config.template at<1>(); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; @@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy struct LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; }; + template<> struct + LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; }; + template<> struct LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;}; // clang-format on @@ -332,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK0 = BlockFmhaShape::kK0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK0 = BlockFmhaShape::kK0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; return typename LdsBufferSequence::type{}; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index f2bb2200fe..570754b22e 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -7,6 +7,20 @@ namespace ck_tile { +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len) +{ + if(len == 96) + return 128; + if(len == 160) + return 256; + + // only length of 96, 160 and power-of-two is supported + if(!(len & (len - 1))) + return len; + + return 0; +}; + template {}); // tile size along qk gemm unroll static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll - static constexpr index_t kK0BlockLength = + static constexpr index_t kQKHeaddim = BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); + static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + + static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;