From c0c2ded56684a3a04ad9df1b907d27ae7635067d Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Tue, 12 Aug 2025 13:02:10 +0200 Subject: [PATCH 1/9] fix (#2668) --- example/ck_tile/01_fmha/fmha_fwd.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) mode change 100755 => 100644 example/ck_tile/01_fmha/fmha_fwd.cpp diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100755 new mode 100644 index 48306e35fe..c0e4dc3d30 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -525,8 +525,8 @@ bool run(const ck_tile::ArgParser& arg_parser) flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(ODataType) * real_seqlen_q * hdim_v); num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + sizeof(VDataType) * hdim_v * real_seqlen_k); } From b7322a521a91fe4762701237f0243dd2c94b7644 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Tue, 12 Aug 2025 19:43:14 +0800 Subject: [PATCH 2/9] Optimize fmha fwd decode & prefill for gfx950 (#2641) * Fix for fwd/bwd kernel build filter * fix bwd code * save an example for __bf16 type * temp save, waiting for debug * tempsave, fmha_decode * temp save, change all instance to 1wave * fix async copytest bug * Add block_sync_lds_direct_load utility * fix the s_waitcnt_imm calculation * Improve s_waitcnt_imm calculation * fix vmcnt shift * add input validation and bug fix * remove unnecessary output * move test_copy into test * temp save * tempsave * compile pass * tempsave, trload+asyncload done * tempsave. asynccopy+trload sanity checked * remove unnecessary features * fix the lds alignment caused performance regression * enable prefill overload operator(). * remove all lds bankconflict with xor layouts * enable larger tile size; upgrade xor pattern * upgrade prefill pipeline; simple iglp; consistent data produce and consume order * small refactor * Load Q through lds, implement xor; * add vmcnt guard before load ktile * Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA * Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug * add __restrict__ to tr load * merge fa_decode pipeline into fmha_fwd api * remove unnecessary files; rename some files * Remove unnecessary changes * bug fix, clang format; * remove non-necessary change * fix clangformat with 18.1.3 * fix bugs * fix bug * fix bug on non-gfx950 * fix bugs in gemm * fix bug in pki4 * tempsave, update the blocksync functions * change the warp setting for hdim32 fmha fwd * clang format * fix conflict. disable all v-col instance for fmha fwd * Fix the bug * clang format --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 147 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 3 + .../ck_tile/01_fmha/script/benchmark_fwd.sh | 11 - .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 21 +- .../core/arch/amd_buffer_addressing.hpp | 17 +- .../arch/amd_buffer_addressing_builtins.hpp | 17 +- include/ck_tile/core/arch/arch.hpp | 27 +- include/ck_tile/core/arch/utility.hpp | 15 + include/ck_tile/core/config.hpp | 10 + include/ck_tile/core/numeric/bfloat16.hpp | 11 + include/ck_tile/core/numeric/pk_fp4.hpp | 2 +- include/ck_tile/core/numeric/pk_int4.hpp | 2 +- include/ck_tile/core/numeric/vector_type.hpp | 12 +- .../unary_element_wise_operation.hpp | 7 - include/ck_tile/ops/fmha.hpp | 2 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 1504 ++++++++++++----- ...block_fmha_bwd_pipeline_default_policy.hpp | 24 +- .../pipeline/block_fmha_pipeline_enum.hpp | 7 + .../pipeline/block_fmha_pipeline_problem.hpp | 2 + ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1177 +++++++++++++ ..._pipeline_qr_ks_vs_async_trload_policy.hpp | 823 +++++++++ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 30 +- .../block/block_gemm_areg_breg_creg_v1.hpp | 180 +- .../ops/gemm/block/block_gemm_problem.hpp | 9 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 48 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 8 + .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 30 +- 31 files changed, 3533 insertions(+), 627 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 6fca800c90..42a9d5148a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,6 +115,7 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -123,6 +124,7 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 269af4e6a7..ce35c6a2a7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -12,6 +12,7 @@ from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file DTYPE_BITS = { @@ -83,6 +84,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, + {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -97,7 +99,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -161,12 +163,19 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; }} """ +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -177,8 +186,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -221,6 +230,7 @@ class FmhaFwdApiTrait: dpad : str dvpad : str skip : str + tr_load : str constraint : CppConstraint @property @@ -231,13 +241,19 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False + + @property + def seqtune(self) -> str: + if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + else: + return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -248,6 +264,9 @@ class FmhaFwdApiTrait: elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' else: assert False @property @@ -256,7 +275,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -268,7 +287,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -290,6 +309,7 @@ class FmhaFwdPipeline: F_squant : str # F_mask : str # value from MASK_MAP F_skip : str # true/false + F_trload : str # true/false F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -331,6 +351,9 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @@ -351,31 +374,39 @@ class FmhaFwdApiPool: @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in ["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) @dataclass class FmhaFwdTileSize: @@ -458,7 +489,8 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) @property def name(self) -> str: @@ -494,6 +526,7 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) class KernelComponentFactory: @@ -503,10 +536,15 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -534,34 +572,27 @@ class KernelComponentFactory: if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -599,6 +630,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -665,10 +702,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index c0e4dc3d30..d0f8e3798c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1135,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush; + << " GB/s" << std::flush << std::endl; if(do_validation == 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 81dda692ea..df1e9e5699 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -1028,6 +1029,7 @@ template struct fmha_fwd_traits_ { @@ -1052,6 +1054,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 599c595a75..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,14 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -for perm in 0 1 ; do - -$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 - -done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index b867cd6c07..dc2be933bd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -42,7 +42,6 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do - for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -51,16 +50,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 35da19cd3e..07be65a150 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1318,6 +1314,17 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template & src_thread_ #if defined(__gfx950__) template -__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 8c3bc0bc36..c64b296408 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1186,6 +1182,17 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template -__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index ab42ec8617..f0e9518120 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -89,21 +89,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } -CK_TILE_DEVICE void block_sync_lds() -{ -#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - // asm volatile("\ - // s_waitcnt lgkmcnt(0) \n \ - // s_barrier \ - // " ::); - - __builtin_amdgcn_s_waitcnt(0xc07f); - __builtin_amdgcn_s_barrier(); -#else - __syncthreads(); -#endif -} - CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) { #ifdef __gfx12__ @@ -174,6 +159,18 @@ CK_TILE_DEVICE void s_waitcnt_barrier() __builtin_amdgcn_s_barrier(); } +template +CK_TILE_DEVICE void block_sync_lds() +{ + s_waitcnt_barrier(); +} + +template +CK_TILE_DEVICE void block_sync_lds_direct_load() +{ + s_waitcnt_barrier(); +} + CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 7184f99521..93008f8525 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } +template +CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) +{ + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const int32x2_t x = __builtin_amdgcn_permlane32_swap( + bit_cast(v_local), bit_cast(v_local), false, false); + + thread_buffer v; + v(0) = bit_cast(x[0]); + v(1) = bit_cast(x[1]); + + return v; +} + template CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) { diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index c471f416c3..e472bd01e5 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -191,6 +191,16 @@ #endif #endif +// use llvm builtin bf16 data type after ROCm 6.5 +#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ + (HIP_VERSION_MAJOR >= 7) +#define CK_TILE_USE_LLVM_BUILTIN_BF16 1 +#else +#define CK_TILE_USE_LLVM_BUILTIN_BF16 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 6f31468809..245fb7244f 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,6 +6,9 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#if CK_TILE_USE_LLVM_BUILTIN_BF16 +#include +#endif #include #pragma once @@ -102,7 +105,11 @@ struct native_t using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else +#if CK_TILE_USE_LLVM_BUILTIN_BF16 +using bfloat16_t = __bf16; +#else using bfloat16_t = ushort; +#endif using bf16_t = bfloat16_t; using bf16_raw_t = uint16_t; #endif @@ -280,7 +287,11 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { +#if defined(__gfx950__) + return static_cast(f); +#else return bit_cast(float_to_bf16_raw(f, constant{})); +#endif } template using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); -using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 58bdb43b08..bbd3d53827 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); // bf16 // using bf16_t = ... -using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); -using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); -using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); -using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); -using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); -using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64))); // i32 // using int32_t = ... diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 0e385901ed..b69c167315 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -330,13 +330,6 @@ struct PassThrough y = type_convert(x); } - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } - template <> CK_TILE_HOST_DEVICE void operator()(float& y, const ck_tile::fp16_t& x) const diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index d8dd5db12e..69f645b850 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -52,6 +52,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 8d257a3329..5b3d38d3e7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -13,6 +13,7 @@ #include #include +#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -61,6 +62,14 @@ struct FmhaFwdKernel static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; +#if defined(__gfx950__) + static constexpr bool kIsAvialable = true; +#else + static constexpr bool kIsAvialable = !kUseTrLoad; +#endif + static constexpr std::string_view kPipelineName = FmhaPipeline::name; + // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; @@ -100,7 +109,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -1036,455 +1045,1142 @@ struct FmhaFwdKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + if constexpr(kIsAvialable) + run_(std::move(kargs)); + } - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - - if constexpr(kIsGroupMode) + CK_TILE_DEVICE void run_(Kargs kargs) const + { + if constexpr(kPipelineName != "qr_async_trload") { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } - batch_offset_o = query_start * kargs.stride_o; + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - if constexpr(kSkipMinSeqlenQ) + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) { - if(kargs.seqlen_q <= kargs.min_seqlen_q) + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) { return; } - } - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = + static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } - } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - } - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; - // Q/K/V DRAM and DRAM window - const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - if constexpr(FmhaPipeline::kQLoadOnce) - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, number<1>{}); - - const auto v_dram_transposed = - transform_tensor_view(v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. + /// Remove following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = + make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = + make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host + ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } + else + { + // TODO: Refine the logical here. + // In Decode case + // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache + // 2. limit the LDS usage, as we want higher occupancy + // In Prefill case + // 1. we expect KV data reused by different ThreadGroups, use cache + // 2. use more LDS, as we want better memory latency hiding + // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the + // cache + constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; + const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; // unused for paged-kvcache + long_index_t batch_offset_v = 0; // unused for paged-kvcache + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + // index_t kv_l2p_offset = + // 0; // logical-to-physical offset of seqlen_k coordinate. only used for + // paged-kvcache + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + kargs.seqlen_k = + kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; + } } else { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = + static_cast(i_batch) * kargs.batch_stride_bias; + } + } + + // for simplicity, batch stride we just modify the pointer + const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; + + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead_k) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead_k) * kargs.nhead_stride_v + + batch_offset_v; + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&] { + const auto q_dram_naive = [&] { + { + return make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + } + }(); + + if constexpr(FmhaPipeline::kQLoadOnce) + { + const auto seqlen_q = kargs.seqlen_q; + const auto q_dram_pad = pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto q_dram_unmerged = transform_tensor_view( + q_dram_pad, + make_tuple( + make_unmerge_transform( + make_tuple(seqlen_q / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto q_dram_merged = transform_tensor_view( + q_dram_unmerged, + make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto q_dram_unmerged_xor = transform_tensor_view( + q_dram_merged, + make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto q_dram_permuted = transform_tensor_view( + q_dram_unmerged_xor, + make_tuple( + make_xor_transform( + make_tuple(seqlen_q / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto q_dram_tmp = transform_tensor_view( + q_dram_permuted, + make_tuple( + make_pass_through_transform(seqlen_q / XorLengthFold), + make_unmerge_transform( + make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + q_dram_tmp, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(seqlen_q / XorLengthFold, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto q_dram_unmerged = transform_tensor_view( + q_dram_pad, + make_tuple( + make_pass_through_transform(seqlen_q), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto q_dram_permuted = transform_tensor_view( + q_dram_unmerged, + make_tuple( + make_xor_transform(make_tuple(seqlen_q, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_view( + q_dram_permuted, + make_tuple( + make_pass_through_transform(seqlen_q), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto make_k_dram = [&](const KDataType* data, index_t height) { + const auto k_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(height, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + const auto k_dram_pad = pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto k_dram_unmerged = transform_tensor_view( + k_dram_pad, + make_tuple(make_unmerge_transform( + make_tuple(height / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto k_dram_merged = transform_tensor_view( + k_dram_unmerged, + make_tuple(make_pass_through_transform(height / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto k_dram_unmerged_xor = transform_tensor_view( + k_dram_merged, + make_tuple(make_pass_through_transform(height / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto k_dram_permuted = transform_tensor_view( + k_dram_unmerged_xor, + make_tuple( + make_xor_transform( + make_tuple(height / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto k_dram_tmp = transform_tensor_view( + k_dram_permuted, + make_tuple( + make_pass_through_transform(height / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + k_dram_tmp, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(height / XorLengthFold, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto k_dram_unmerged = transform_tensor_view( + k_dram_pad, + make_tuple( + make_pass_through_transform(height), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto k_dram_permuted = transform_tensor_view( + k_dram_unmerged, + make_tuple( + make_xor_transform(make_tuple( + height, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_view( + k_dram_permuted, + make_tuple( + make_pass_through_transform(height), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + }; + const auto k_dram = [&]() { + { + return make_k_dram(k_ptr, kargs.seqlen_k); + } + }(); + + const auto make_v_dram = [&](const VDataType* data, index_t length) { const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), + make_tuple(kargs.hdim_v, 1), number{}, number<1>{}); - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( + // TODO: Add kVHeadDim + constexpr index_t XorGroupSize = + FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + + const auto v_dram_pad = pad_tensor_view( v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); + make_tuple(number{}, number{}), + sequence{}); - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto v_dram_unmerged = transform_tensor_view( + v_dram_pad, + make_tuple(make_unmerge_transform( + make_tuple(length / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto v_dram_merged = transform_tensor_view( + v_dram_unmerged, + make_tuple(make_pass_through_transform(length / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto v_dram_unmerged_xor = transform_tensor_view( + v_dram_merged, + make_tuple( + make_pass_through_transform(length / XorLengthFold), + make_unmerge_transform(make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto v_dram_permuted = transform_tensor_view( + v_dram_unmerged_xor, + make_tuple( + make_xor_transform(make_tuple(length / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto v_dram_tmp = transform_tensor_view( + v_dram_permuted, + make_tuple(make_pass_through_transform(length / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + v_dram_tmp, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(length / XorLengthFold, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto v_dram_unmerged = transform_tensor_view( + v_dram_pad, + make_tuple(make_pass_through_transform(length), + make_unmerge_transform( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + const auto v_dram_permuted = transform_tensor_view( + v_dram_unmerged, + make_tuple(make_xor_transform(make_tuple( + length, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove - /// following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; + return transform_tensor_view( + v_dram_permuted, + make_tuple(make_pass_through_transform(length), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } }; - }(); - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) + const auto v_dram = [&]() { { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); + return make_v_dram(v_ptr, kargs.seqlen_k); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {0, 0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. + /// Remove following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = + make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); } else { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + return make_null_tile_window(bias_dram_window_lengths); } - } - else - { - return EmptyPositionEncoding{}; - } - }(); + }(); - AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); + // lse acc + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + + batch_offset_lse; - BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + const auto lse_dram = [&] { + const auto lse_dram_naive = [&] { + { + return make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + } + }(); + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); - auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - }(); + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); - // O DRAM and O DRAM window - auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); - return pad_tensor_view( - o_dram_naive, + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask( + slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(PrefillCase) + { + // allocate double lds + // add __restrict__ here to avoid aliasing + __shared__ char smem_ptrk0 + [FmhaPipeline::Policy::template GetSmemSizeK()]; + __shared__ char smem_ptrk1 + [FmhaPipeline::Policy::template GetSmemSizeK()]; + __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV< + typename FmhaPipeline::Problem>()]; + __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV< + typename FmhaPipeline::Problem>()]; + + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + smem_ptrk0, + smem_ptrk1, + smem_ptrv0, + smem_ptrv1); + } + else + { + __shared__ char smem_ptr[GetSmemSize()]; + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + smem_ptr); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_dram = [&] { + const auto o_dram_naive = [&] { + { + return make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + } + }(); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, make_tuple(number{}, number{}), - sequence{}); - }(); + {i_m0, i_n1}); - auto o_dram_window = - make_tile_window(o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index aa2ec99590..f6a20c5cb5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), pt_warp_tensor.get_thread_buffer()); }); @@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), dst_warp_tensor.get_thread_buffer()); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index cf70dff63f..45a1c8f4b8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -11,6 +11,7 @@ enum class BlockFmhaPipelineEnum QRKSVS = 0, QRKSVS_ASYNC, QSKSVS, + QRKSVS_ASYNC_TRLOAD, }; template @@ -32,4 +33,10 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qs"; }; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async_trload"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 20b30b7417..86ac713b6f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -22,6 +22,7 @@ template struct BlockFmhaPipelineProblem { @@ -46,6 +47,7 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp new file mode 100644 index 0000000000..39d8814692 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -0,0 +1,1177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSAsyncTrload +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1); + static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1); + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + // Problem::kPadHeadDimV == true); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = + Problem::kPadHeadDimQ; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = + Problem::kPadHeadDimV; // support multiple of vector(like 8x) + + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasUnevenSplits = true; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO(); + + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async_trload"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + // Decode + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && + kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], + "wrong!"); + ignore = bias_dram_block_window_tmp; + ignore = position_encoding; + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + auto o_acc = OaccBlockTileType{}; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(o_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + // init M, L + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // Q tile in LDS + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); + + auto q_lds_write_view = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_read_view = make_tensor_view( + static_cast(smem_ptr), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_store_window = + make_tile_window(q_lds_write_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_read_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + + async_load_tile(q_lds_store_window, q_dram_window); + + // K tile in LDS + const index_t physical_seqlen_k_start = logical_seqlen_k_start; + const index_t physical_seqlen_k_end = logical_seqlen_k_end; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + + auto k_lds_write_view = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_read_view = make_tensor_view( + static_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds_write_view, + Policy::template MakeKLdsBlockDescriptor().get_lengths(), + {0, 0}); + auto k_lds_read_window = + make_tile_window(k_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeK()), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = + make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // V tile in LDS + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + + auto v_lds_write_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeS()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_read_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeS()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds_write_view, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeVRegTileDistribution()); + + block_sync_lds_direct_load<0>(); + auto q_tile = load_tile(q_lds_read_window); + + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + + block_sync_lds(); + async_load_tile(k_lds_write_window, k_dram_window); + + constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); + constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + + do + { + block_sync_lds(); + async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile + + // move V tile windows + move_tile_window(v_dram_window, {kN0, 0}); + + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + + if constexpr(1 < k0_loops) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + if constexpr(i_k0 == 0) + { + block_sync_lds_direct_load(); + } + else + { + block_sync_lds_direct_load<0>(); + } + + auto k_tile = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_tile); + + // loop over along the [K]ey head dimension + move_tile_window(k_dram_window, {0, kK0}); + block_sync_lds(); + async_load_tile(k_lds_write_window, k_dram_window); + }); + // move back to the origin + move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)}); + } + + if constexpr(k0_loops == 1) + { + block_sync_lds_direct_load(); + } + else + { + block_sync_lds_direct_load<0>(); + } + + auto k_tile = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_tile); + + if constexpr(kHasUnevenSplits) + { + if(i_total_loops == (num_total_loop - 1)) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(I0) + tile_idx.at(I1); + + { + return physical_seqlen_k_end_ <= col; + } + }); + } + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(I0) + tile_idx.at(I0); + const auto col = k_origin.at(I0) + tile_idx.at(I1); + return mask.IsOutOfBound(row, col); + }); + } + } + + // move K tile windows after current status checked + // prefetch next-tile along [K]ey sequence length dimension + move_tile_window(k_dram_window, {kN0, 0}); + + block_sync_lds(); + async_load_tile(k_lds_write_window, k_dram_window); + + // Gemm1 + auto s_new = [&]() { + if constexpr(kNWarp > 1) + { + auto s = cast_tile(s_acc); // S{j} + + store_tile(s_write_lds_window, s); + block_sync_lds(); + return load_tile(s_read_lds_window); + } + else + { + return cast_tile(s_acc); // S{j} + } + }(); + + auto m_local = block_tile_reduce( + s_new, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + // Set CrossWarp to false will trigger better strategy on gfx950, but will cause + // performance regression because of un-coexecutable packed math, silent it for now + block_tile_reduce_sync( + m_local, f_max, bool_constant{} /*, bool_constant{}*/); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s_new.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_max = scale_s * get_validated_m(m[i_idx]); + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } + } + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync( + rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); + + auto p_tile = make_static_distributed_tensor( + Policy::template MakePRegTileDistribution()); + p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds_direct_load(); + + auto v_tile = load_tile_transpose(v_lds_read_window); + + if constexpr(1 < k1_loops) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}), + v_tile); + + // loop over along the [V]alue Sequence length + move_tile_window(v_lds_read_window, {kK1, 0}); + v_tile = load_tile_transpose(v_lds_read_window); + }); + // move back to the origin + move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); + } + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + v_tile); + + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } + } + }); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + // Prefill, double lds + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* __restrict__ smem_ptrk0, + void* __restrict__ smem_ptrk1, + void* __restrict__ smem_ptrv0, + void* __restrict__ smem_ptrv1) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && + kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], + "wrong!"); + ignore = bias_dram_block_window_tmp; + ignore = position_encoding; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + auto o_acc = OaccBlockTileType{}; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(o_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + // init M, L + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // Q tile in LDS + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); + + auto q_lds_write_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_read_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_store_window = + make_tile_window(q_lds_write_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_read_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + + async_load_tile(q_lds_store_window, q_dram_window); + block_sync_lds_direct_load<0>(); + auto q_tile = load_tile(q_lds_read_window); + + // K tile in LDS + const index_t physical_seqlen_k_start = logical_seqlen_k_start; + const index_t physical_seqlen_k_end = logical_seqlen_k_end; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + + auto k_lds_write_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_read_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds_write_view, + Policy::template MakeKLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptrk0) + + Policy::template GetSmemSizeK()), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = + make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // V tile in LDS + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + + auto v_lds_write_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); + + auto v_lds_read_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds_write_view, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeVRegTileDistribution()); + + // block_sync_lds_direct_load<0>(); + // auto q_tile = load_tile(q_lds_read_window); + + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + block_sync_lds<0>(); + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + + move_tile_window(k_dram_window, {kN0, 0}); + k_lds_write_window.set_bottom_tensor_view_data_ptr( + static_cast(smem_ptrk1)); + async_load_tile(k_lds_write_window, k_dram_window); + + constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); + constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + + constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access(); + constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access(); + + block_sync_lds_direct_load(); + auto k_tile = load_tile(k_lds_read_window); + + __builtin_amdgcn_sched_barrier(0); + + auto mainloop = [&](index_t cur_loop) { + const bool is_even_loop = (cur_loop % 2 == 0); + + auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) + : static_cast(smem_ptrk1); + auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) + : static_cast(smem_ptrk0); + auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) + : static_cast(smem_ptrv0); + auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) + : static_cast(smem_ptrv1); + + // move V tile windows + block_sync_lds(); + move_tile_window(v_dram_window, {kN0, 0}); + v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr); + async_load_tile(v_lds_write_window, v_dram_window); + + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + + if constexpr(1 < k0_loops) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + // loop over along the [K]ey head dimension + move_tile_window(k_lds_read_window, {0, kK0}); + auto k_tile_switch = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_tile); + + k_tile = k_tile_switch; + }); + // move back to the origin + move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); + } + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_tile); + + block_sync_lds_direct_load(); + v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr); + auto v_tile = load_tile_transpose(v_lds_read_window); + + if constexpr(kHasUnevenSplits) + { + if(i_total_loops == (num_total_loop - 1)) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(I0) + tile_idx.at(I1); + + { + return physical_seqlen_k_end_ <= col; + } + }); + } + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(I0) + tile_idx.at(I0); + const auto col = k_origin.at(I0) + tile_idx.at(I1); + return mask.IsOutOfBound(row, col); + }); + } + } + + // Gemm1 + auto s_new = [&]() { + if constexpr(kNWarp > 1) + { + auto s = cast_tile(s_acc); // S{j} + + store_tile(s_write_lds_window, s); + block_sync_lds(); + return load_tile(s_read_lds_window); + } + else + { + return cast_tile(s_acc); // S{j} + } + }(); + + auto m_local = block_tile_reduce( + s_new, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync( + m_local, f_max, bool_constant{} /*, bool_constant{}*/); + + static_for<0, 12, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s_new.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_max = scale_s * get_validated_m(m[i_idx]); + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } + } + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync( + rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); + + auto p_tile = make_static_distributed_tensor( + Policy::template MakePRegTileDistribution()); + p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + move_tile_window(k_dram_window, {kN0, 0}); + k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr); + async_load_tile(k_lds_write_window, k_dram_window); + + if constexpr(1 < k1_loops) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + // loop over along the [V]alue Sequence length + move_tile_window(v_lds_read_window, {kK1, 0}); + auto v_tile_switch = load_tile_transpose(v_lds_read_window); + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}), + v_tile); + + v_tile = v_tile_switch; + }); + // move back to the origin + move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); + } + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + v_tile); + + block_sync_lds_direct_load(); + k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr); + k_tile = load_tile(k_lds_read_window); + + static_for<0, 12, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + }; + + do + { + mainloop(i_total_loops); + i_total_loops++; + } while(i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } + } + }); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp new file mode 100644 index 0000000000..ed22758566 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp @@ -0,0 +1,823 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" + +// can remove all bank conflicts, but drop the performance for some cases +// Probably it is limited by compiler optimization. +#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 +namespace ck_tile { +// This pipeline is qkv all located in LDS +struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + // this should align with MakeQDramTileDistribution() + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + + return static_cast(16 / sizeof(OaccDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + if constexpr(!BypassLDS) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = + LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read M first, then K + // This is the same data consume order as BlockGEMM + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return static_cast(16 / sizeof(QDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t kKPack = GetSmemKPackQ(); + + constexpr auto q_lds_block_desc = [&]() { + if constexpr(Xor) + { +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType); + constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( + q_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + q_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( + q_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return q_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = + LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; + + constexpr index_t kKPack = GetSmemKPackK(); + + constexpr auto k_lds_block_desc = [&]() { + if constexpr(Xor) + { +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType); + constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + k_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto v_lds_block_desc = [&]() { + if constexpr(Xor) + { + constexpr auto XorGroupSize = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType); + constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple( + number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return v_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + GemmLoopOrder::MNK>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + GemmLoopOrder::KMN>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read N first, then K + // This is the same data consume order as BlockGEMM + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t NPerThread = kMaxVecLoad; + constexpr index_t NThreads = kNPerBlock / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read M first, then K + // This is the same data consume order as BlockGEMM + constexpr auto p_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto p_block_dstr = make_static_tile_distribution(p_block_dstr_encode); + + return p_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read N first, then K + // This is the same data consume order as BlockGEMM + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() + { + using SDataType = remove_cvref_t; + return static_cast(16 / sizeof(SDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPack = GetSmemNPackS(); + + constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto s_lds_block_desc = transform_tensor_descriptor( + s_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return s_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + // static_assert(MWarp == 1, "Check failed!"); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; + + // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm + constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = kKPerBlock / (K2 * K3); + constexpr index_t K0 = kTileK / kKPerBlock; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + constexpr auto s2_block_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<2, 2>>, + sequence<1, 2, 2, 2>, + sequence<0, 0, 1, 3>>{}; + + constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); + + return s2_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return MakeQLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS() + { + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + return NWarp > 1 ? MakeSLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::SaccDataType) + : 0; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // Alignment on gfx950 is 1280 Bytes + // Alignment before gfx950 is 512 Bytes. + return max(GetSmemSizeQ(), + GetSmemSizeK() + GetSmemSizeS() + GetSmemSizeV()); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 3489d6f9a1..e2cea97f9a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -383,23 +383,31 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return 16 / sizeof(VDataType); + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + + return kMaxVecLoad; } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; - using VDataType = remove_cvref_t; + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + if constexpr(std::is_same_v) { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) @@ -410,7 +418,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; + static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -86,17 +89,36 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; - return a_block_dstr_encode; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } } @@ -118,17 +140,33 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + return b_block_dstr_encode; + } } } @@ -213,40 +251,82 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - using c_iter_idx = std:: - conditional_t, sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std::conditional_t, + sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); }); }); - }); + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index fd5211a59a..d0be065fc9 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -13,7 +14,8 @@ template + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN, + index_t NumWaveGroups_ = 1> struct BlockGemmProblem { using ADataType = remove_cvref_t; @@ -21,8 +23,9 @@ struct BlockGemmProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t NumWaveGroups = NumWaveGroups_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index b18bf603a9..b3c86b9456 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -39,6 +39,12 @@ enum struct TailNumber Full, }; +enum struct GemmLoopOrder +{ + KMN, + MNK, +}; + } // namespace ck_tile inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..c628614b54 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,10 +14,11 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -45,9 +46,10 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; // In the base situation, the Preshuffle setting should be false. static constexpr bool Preshuffle = false; @@ -167,10 +169,11 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> using GemmPipelineProblem = GemmPipelineProblemBase; + VectorSizeB_, + BlockGemmLoopOrder_>; template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -224,8 +229,9 @@ struct UniversalGemmPipelineProblem static constexpr auto Scheduler = Scheduler_; static constexpr bool Preshuffle = Traits::Preshuffle; - static constexpr index_t VectorSizeA = VectorSizeA_; - static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index fb191d565d..d1deaf9e0e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -104,6 +104,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = 1>>; #endif +using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = + WarpGemmImpl>>; + #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; #endif +using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = + WarpGemmImpl>>; + #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -74,6 +76,8 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 434be9f84a..7a10d1fa56 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -14,10 +14,14 @@ namespace ck_tile { * Y dim must have at least one dim not been reduced */ // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) -template +template CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, const ReduceFunc& reduce_func, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, // reduction sweep forward static_for<0, nstage, 1>{}([&](auto istage) { - constexpr index_t lid_delta = - lid_over_rid_derivative * (1 << (nstage - istage - 1)); + if constexpr(CrossWarp) + { + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); - // pull data from remote lane - const auto v_remote = warp_shuffle_down(v_local, lid_delta); + // pull data from remote lane + const auto v_remote = warp_shuffle_down(v_local, lid_delta); - // reduce - v_local = reduce_func(v_local, v_remote); + // reduce + v_local = reduce_func(v_local, v_remote); + } + else + { + // pull data from remote lane + const auto v_swapped_regs = warp_shuffle_down_pair(v_local); + // reduce + v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1)); + } }); } }); From 5b39de4bb61a3f0399fcd384f3a82c5e6ce28e5e Mon Sep 17 00:00:00 2001 From: asleepzzz Date: Tue, 12 Aug 2025 20:27:10 +0800 Subject: [PATCH 3/9] Revert "Optimize fmha fwd decode & prefill for gfx950 (#2641)" (#2670) This reverts commit b7322a521a91fe4762701237f0243dd2c94b7644. --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 - .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 147 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 3 - .../ck_tile/01_fmha/script/benchmark_fwd.sh | 11 + .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 21 +- .../core/arch/amd_buffer_addressing.hpp | 17 +- .../arch/amd_buffer_addressing_builtins.hpp | 17 +- include/ck_tile/core/arch/arch.hpp | 27 +- include/ck_tile/core/arch/utility.hpp | 15 - include/ck_tile/core/config.hpp | 10 - include/ck_tile/core/numeric/bfloat16.hpp | 11 - include/ck_tile/core/numeric/pk_fp4.hpp | 2 +- include/ck_tile/core/numeric/pk_int4.hpp | 2 +- include/ck_tile/core/numeric/vector_type.hpp | 12 +- .../unary_element_wise_operation.hpp | 7 + include/ck_tile/ops/fmha.hpp | 2 - .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 1530 +++++------------ ...block_fmha_bwd_pipeline_default_policy.hpp | 24 +- .../pipeline/block_fmha_pipeline_enum.hpp | 7 - .../pipeline/block_fmha_pipeline_problem.hpp | 2 - ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1177 ------------- ..._pipeline_qr_ks_vs_async_trload_policy.hpp | 823 --------- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 30 +- .../block/block_gemm_areg_breg_creg_v1.hpp | 178 +- .../ops/gemm/block/block_gemm_problem.hpp | 9 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 - .../gemm/pipeline/gemm_pipeline_problem.hpp | 48 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 8 - .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 - .../ck_tile/ops/reduce/block/block_reduce.hpp | 30 +- 31 files changed, 639 insertions(+), 3545 deletions(-) delete mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp delete mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..6fca800c90 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,7 +115,6 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -124,7 +123,6 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ce35c6a2a7..269af4e6a7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -12,7 +12,6 @@ from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * -from codegen.utils import update_file DTYPE_BITS = { @@ -84,7 +83,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, - {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -99,7 +97,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; #include @@ -163,19 +161,12 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; }} """ -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -186,8 +177,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -230,7 +221,6 @@ class FmhaFwdApiTrait: dpad : str dvpad : str skip : str - tr_load : str constraint : CppConstraint @property @@ -241,19 +231,13 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async', 'qr_async_trload']: + if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: - return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -264,9 +248,6 @@ class FmhaFwdApiTrait: elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' else: assert False @property @@ -275,7 +256,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -287,7 +268,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -309,7 +290,6 @@ class FmhaFwdPipeline: F_squant : str # F_mask : str # value from MASK_MAP F_skip : str # true/false - F_trload : str # true/false F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -351,9 +331,6 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' return n @@ -374,39 +351,31 @@ class FmhaFwdApiPool: @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } - - per_tr_load =str() - for tr_load in ["t", "f"]: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) - if not per_tr_load: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass class FmhaFwdTileSize: @@ -489,8 +458,7 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: @@ -526,7 +494,6 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) class KernelComponentFactory: @@ -536,15 +503,10 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -572,27 +534,34 @@ class KernelComponentFactory: if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -630,12 +599,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -702,10 +665,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + (autogen_dir / kernel.filename).write_text(kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d0f8e3798c..c0e4dc3d30 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1135,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush << std::endl; + << " GB/s" << std::flush; if(do_validation == 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..81dda692ea 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -1029,7 +1028,6 @@ template struct fmha_fwd_traits_ { @@ -1054,7 +1052,6 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; - static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..599c595a75 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,14 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +for perm in 0 1 ; do + +$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 + +done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index dc2be933bd..b867cd6c07 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -42,6 +42,7 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do + for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -50,16 +51,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07be65a150..35da19cd3e 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,6 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1314,17 +1318,6 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, - // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 - // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system - // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse - WAVE_NT0 = 0, - WAVE_NT1 = 2, - GROUP_NT0 = 1, - GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, }; template & src_thread_ #if defined(__gfx950__) template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index c64b296408..8c3bc0bc36 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -32,6 +32,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1182,17 +1186,6 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, - // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 - // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system - // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse - WAVE_NT0 = 0, - WAVE_NT1 = 2, - GROUP_NT0 = 1, - GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, }; template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index f0e9518120..ab42ec8617 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -89,6 +89,21 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } +CK_TILE_DEVICE void block_sync_lds() +{ +#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +#else + __syncthreads(); +#endif +} + CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) { #ifdef __gfx12__ @@ -159,18 +174,6 @@ CK_TILE_DEVICE void s_waitcnt_barrier() __builtin_amdgcn_s_barrier(); } -template -CK_TILE_DEVICE void block_sync_lds() -{ - s_waitcnt_barrier(); -} - -template -CK_TILE_DEVICE void block_sync_lds_direct_load() -{ - s_waitcnt_barrier(); -} - CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 93008f8525..7184f99521 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -59,21 +59,6 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } -template -CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) -{ - static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); - - const int32x2_t x = __builtin_amdgcn_permlane32_swap( - bit_cast(v_local), bit_cast(v_local), false, false); - - thread_buffer v; - v(0) = bit_cast(x[0]); - v(1) = bit_cast(x[1]); - - return v; -} - template CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) { diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index e472bd01e5..c471f416c3 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -191,16 +191,6 @@ #endif #endif -// use llvm builtin bf16 data type after ROCm 6.5 -#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16 -#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ - (HIP_VERSION_MAJOR >= 7) -#define CK_TILE_USE_LLVM_BUILTIN_BF16 1 -#else -#define CK_TILE_USE_LLVM_BUILTIN_BF16 0 -#endif -#endif - #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 245fb7244f..6f31468809 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,9 +6,6 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/numeric.hpp" -#if CK_TILE_USE_LLVM_BUILTIN_BF16 -#include -#endif #include #pragma once @@ -105,11 +102,7 @@ struct native_t using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else -#if CK_TILE_USE_LLVM_BUILTIN_BF16 -using bfloat16_t = __bf16; -#else using bfloat16_t = ushort; -#endif using bf16_t = bfloat16_t; using bf16_raw_t = uint16_t; #endif @@ -287,11 +280,7 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { -#if defined(__gfx950__) - return static_cast(f); -#else return bit_cast(float_to_bf16_raw(f, constant{})); -#endif } template using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); -using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index bbd3d53827..58bdb43b08 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); // bf16 // using bf16_t = ... -using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); -using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4))); -using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8))); -using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16))); -using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32))); -using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); // i32 // using int32_t = ... diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index b69c167315..0e385901ed 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -330,6 +330,13 @@ struct PassThrough y = type_convert(x); } + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const + { + y = type_convert(x); + } + template <> CK_TILE_HOST_DEVICE void operator()(float& y, const ck_tile::fp16_t& x) const diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 69f645b850..d8dd5db12e 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -52,8 +52,6 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 5b3d38d3e7..8d257a3329 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -13,7 +13,6 @@ #include #include -#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -62,14 +61,6 @@ struct FmhaFwdKernel static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; -#if defined(__gfx950__) - static constexpr bool kIsAvialable = true; -#else - static constexpr bool kIsAvialable = !kUseTrLoad; -#endif - static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; @@ -109,7 +100,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on @@ -1045,1142 +1036,455 @@ struct FmhaFwdKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - if constexpr(kIsAvialable) - run_(std::move(kargs)); - } + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; - CK_TILE_DEVICE void run_(Kargs kargs) const - { - if constexpr(kPipelineName != "qr_async_trload") + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) { - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - - if constexpr(kIsGroupMode) + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - - if constexpr(kSkipMinSeqlenQ) - { - if(kargs.seqlen_q <= kargs.min_seqlen_q) - { - return; - } - } - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - } + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) { - batch_offset_bias = - static_cast(i_batch) * kargs.batch_stride_bias; + return; } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } - // Q/K/V DRAM and DRAM window - const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - if constexpr(FmhaPipeline::kQLoadOnce) - { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); - - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); - - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); - - auto v_dram_window = make_tile_window( - v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. - /// Remove following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = - make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + - batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = - make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host - ? kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; - }; - }(); - - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); - - AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); - - BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - - auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - }(); - - // O DRAM and O DRAM window - auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto o_dram_window = make_tile_window( - o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_dram_window, o_acc_tile); + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } } else { - // TODO: Refine the logical here. - // In Decode case - // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache - // 2. limit the LDS usage, as we want higher occupancy - // In Prefill case - // 1. we expect KV data reused by different ThreadGroups, use cache - // 2. use more LDS, as we want better memory latency hiding - // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the - // cache - constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; - const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; // unused for paged-kvcache - long_index_t batch_offset_v = 0; // unused for paged-kvcache - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - // index_t kv_l2p_offset = - // 0; // logical-to-physical offset of seqlen_k coordinate. only used for - // paged-kvcache - - if constexpr(kIsGroupMode) + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) { - batch_offset_v = key_start * kargs.stride_v; + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); } else { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - - batch_offset_lse = query_start; - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - kargs.seqlen_k = - kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; } } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = - static_cast(i_batch) * kargs.batch_stride_bias; - } + return EmptyPositionEncoding{}; } + }(); - // for simplicity, batch stride we just modify the pointer - const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_v + - batch_offset_v; + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + }(); - // Q/K/V DRAM and DRAM window - const auto q_dram = [&] { - const auto q_dram_naive = [&] { - { - return make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - } - }(); + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); - if constexpr(FmhaPipeline::kQLoadOnce) - { - const auto seqlen_q = kargs.seqlen_q; - const auto q_dram_pad = pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_unmerge_transform( - make_tuple(seqlen_q / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_merged = transform_tensor_view( - q_dram_unmerged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto q_dram_unmerged_xor = transform_tensor_view( - q_dram_merged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged_xor, - make_tuple( - make_xor_transform( - make_tuple(seqlen_q / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_tmp = transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform( - make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - q_dram_tmp, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(seqlen_q / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_pass_through_transform(seqlen_q), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged, - make_tuple( - make_xor_transform(make_tuple(seqlen_q, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto make_k_dram = [&](const KDataType* data, index_t height) { - const auto k_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(height, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - const auto k_dram_pad = pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple(make_unmerge_transform( - make_tuple(height / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_merged = transform_tensor_view( - k_dram_unmerged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto k_dram_unmerged_xor = transform_tensor_view( - k_dram_merged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged_xor, - make_tuple( - make_xor_transform( - make_tuple(height / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_tmp = transform_tensor_view( - k_dram_permuted, - make_tuple( - make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - k_dram_tmp, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(height / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple( - make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged, - make_tuple( - make_xor_transform(make_tuple( - height, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - k_dram_permuted, - make_tuple( - make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - const auto k_dram = [&]() { - { - return make_k_dram(k_ptr, kargs.seqlen_k); - } - }(); - - const auto make_v_dram = [&](const VDataType* data, index_t length) { - const auto v_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(length, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), - number{}, - number<1>{}); - - // TODO: Add kVHeadDim - constexpr index_t XorGroupSize = - FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); - - const auto v_dram_pad = pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple(make_unmerge_transform( - make_tuple(length / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_merged = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto v_dram_unmerged_xor = transform_tensor_view( - v_dram_merged, - make_tuple( - make_pass_through_transform(length / XorLengthFold), - make_unmerge_transform(make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged_xor, - make_tuple( - make_xor_transform(make_tuple(length / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_tmp = transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - v_dram_tmp, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(length / XorLengthFold, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple(make_pass_through_transform(length), - make_unmerge_transform( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_xor_transform(make_tuple( - length, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - - const auto v_dram = [&]() { - { - return make_v_dram(v_ptr, kargs.seqlen_k); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); - - auto v_dram_window = make_tile_window( - v_dram, - make_tuple(number{}, number{}), - {0, 0}); - - /// FIXME: Before C++20, capturing structured binding variables are not supported. - /// Remove following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = - make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse acc - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + - batch_offset_lse; - - const auto lse_dram = [&] { - const auto lse_dram_naive = [&] { - { - return make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - } - }(); - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask( - slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); - - auto o_acc_tile = [&]() { - if constexpr(PrefillCase) - { - // allocate double lds - // add __restrict__ here to avoid aliasing - __shared__ char smem_ptrk0 - [FmhaPipeline::Policy::template GetSmemSizeK()]; - __shared__ char smem_ptrk1 - [FmhaPipeline::Policy::template GetSmemSizeK()]; - __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV< - typename FmhaPipeline::Problem>()]; - __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV< - typename FmhaPipeline::Problem>()]; - - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - smem_ptrk0, - smem_ptrk1, - smem_ptrv0, - smem_ptrv1); - } - else - { - __shared__ char smem_ptr[GetSmemSize()]; - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - smem_ptr); - } - }(); - - // Oacc DRAM and Oacc DRAM window - auto o_dram = [&] { - const auto o_dram_naive = [&] { - { - return make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); - } - }(); - - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto o_dram_window = make_tile_window( - o_dram, + return pad_tensor_view( + o_dram_naive, make_tuple(number{}, number{}), - {i_m0, i_n1}); + sequence{}); + }(); - EpiloguePipeline{}(o_dram_window, o_acc_tile); - } + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index f6a20c5cb5..aa2ec99590 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), pt_warp_tensor.get_thread_buffer()); }); @@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), dst_warp_tensor.get_thread_buffer()); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 45a1c8f4b8..cf70dff63f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -11,7 +11,6 @@ enum class BlockFmhaPipelineEnum QRKSVS = 0, QRKSVS_ASYNC, QSKSVS, - QRKSVS_ASYNC_TRLOAD, }; template @@ -33,10 +32,4 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qs"; }; -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qr_async_trload"; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 86ac713b6f..20b30b7417 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -22,7 +22,6 @@ template struct BlockFmhaPipelineProblem { @@ -47,7 +46,6 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp deleted file mode 100644 index 39d8814692..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ /dev/null @@ -1,1177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -// This pipeline is qkv all located in LDS -template -struct BlockFmhaPipelineQRKSVSAsyncTrload -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using AttentionVariant = remove_cvref_t; - using FmhaMask = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static_assert(kQLoadOnce == Policy::QLoadOnce); - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1); - static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1); - - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - - // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && - // Problem::kPadHeadDimV == true); - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = - Problem::kPadHeadDimQ; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = - Problem::kPadHeadDimV; // support multiple of vector(like 8x) - - static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; - static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasUnevenSplits = true; - - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && - (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || - !kHasLogitsSoftCap)) || - (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); - - // last dimension vector length used to create tensor view(and decide buffer_load vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); - - static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO(); - - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - - static constexpr index_t kBlockPerCu = []() { - if constexpr(Problem::kBlockPerCu != -1) - return Problem::kBlockPerCu; - else - { - if constexpr(kQKHeaddim <= 32) - { - return 2; - } - else if constexpr(kQKHeaddim <= 64) - { - return 3; - } - else if constexpr(kQKHeaddim <= 128) - { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256) - return 1; - else - return 2; - } - else if constexpr(kQKHeaddim <= 256) - { - return 1; - } - else - { - return 1; - } - } - }(); - - static constexpr const char* name = "qr_async_trload"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - // Decode - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* smem_ptr) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && - kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], - "wrong!"); - ignore = bias_dram_block_window_tmp; - ignore = position_encoding; - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - auto o_acc = OaccBlockTileType{}; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(o_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - // init M, L - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); - - const auto q_origin = q_dram_block_window_tmp.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) - { - const index_t logical_num_total_loop = - integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); - if(logical_num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse_acc = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse_acc, -numeric::infinity()); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. - return o_acc; - } - } - - // Q tile in LDS - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); - - auto q_lds_write_view = make_tensor_view( - static_cast(smem_ptr), Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_read_view = make_tensor_view( - static_cast(smem_ptr), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_store_window = - make_tile_window(q_lds_write_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto q_lds_read_window = - make_tile_window(q_lds_read_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - - async_load_tile(q_lds_store_window, q_dram_window); - - // K tile in LDS - const index_t physical_seqlen_k_start = logical_seqlen_k_start; - const index_t physical_seqlen_k_end = logical_seqlen_k_end; - // make sure the first tile is completely located in page-block (page-block size should be - // divisible by kN0) - // relationship between each *_start variables: aligned_physical_seqlen_k_start <= - // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start - const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); - - auto k_lds_write_view = make_tensor_view( - static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_read_view = make_tensor_view( - static_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_write_window = - make_tile_window(k_lds_write_view, - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0}); - auto k_lds_read_window = - make_tile_window(k_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution()); - - // S tile in LDS - auto s_lds = make_tensor_view( - reinterpret_cast(reinterpret_cast(smem_ptr) + - Policy::template GetSmemSizeK()), - Policy::template MakeSLdsBlockDescriptor()); - auto s_write_lds_window = make_tile_window( - s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); - auto s_read_lds_window = - make_tile_window(s_lds, - Policy::template MakeSLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeSRegTileDistribution()); - - // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); - - auto v_lds_write_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeS()), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_read_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeS()), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_write_window = - make_tile_window(v_lds_write_view, - Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto v_lds_read_window = - make_tile_window(v_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution()); - - block_sync_lds_direct_load<0>(); - auto q_tile = load_tile(q_lds_read_window); - - const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); - - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - - constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); - constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); - - do - { - block_sync_lds(); - async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile - - // move V tile windows - move_tile_window(v_dram_window, {kN0, 0}); - - // STAGE 1, QK gemm - clear_tile(s_acc); // initialize C - - if constexpr(1 < k0_loops) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - if constexpr(i_k0 == 0) - { - block_sync_lds_direct_load(); - } - else - { - block_sync_lds_direct_load<0>(); - } - - auto k_tile = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_tile); - - // loop over along the [K]ey head dimension - move_tile_window(k_dram_window, {0, kK0}); - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - }); - // move back to the origin - move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)}); - } - - if constexpr(k0_loops == 1) - { - block_sync_lds_direct_load(); - } - else - { - block_sync_lds_direct_load<0>(); - } - - auto k_tile = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_tile); - - if constexpr(kHasUnevenSplits) - { - if(i_total_loops == (num_total_loop - 1)) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - set_tile_if(s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(I0) + tile_idx.at(I1); - - { - return physical_seqlen_k_end_ <= col; - } - }); - } - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - - bool need_perpixel_check = - mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(I0) + tile_idx.at(I0); - const auto col = k_origin.at(I0) + tile_idx.at(I1); - return mask.IsOutOfBound(row, col); - }); - } - } - - // move K tile windows after current status checked - // prefetch next-tile along [K]ey sequence length dimension - move_tile_window(k_dram_window, {kN0, 0}); - - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - - // Gemm1 - auto s_new = [&]() { - if constexpr(kNWarp > 1) - { - auto s = cast_tile(s_acc); // S{j} - - store_tile(s_write_lds_window, s); - block_sync_lds(); - return load_tile(s_read_lds_window); - } - else - { - return cast_tile(s_acc); // S{j} - } - }(); - - auto m_local = block_tile_reduce( - s_new, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - // Set CrossWarp to false will trigger better strategy on gfx950, but will cause - // performance regression because of un-coexecutable packed math, silent it for now - block_tile_reduce_sync( - m_local, f_max, bool_constant{} /*, bool_constant{}*/); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s_new.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - auto row_max = scale_s * get_validated_m(m[i_idx]); - sweep_tile_span(p_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); - } - } - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync( - rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); - - auto p_tile = make_static_distributed_tensor( - Policy::template MakePRegTileDistribution()); - p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); - - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - } - }(); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds_direct_load(); - - auto v_tile = load_tile_transpose(v_lds_read_window); - - if constexpr(1 < k1_loops) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}), - v_tile); - - // loop over along the [V]alue Sequence length - move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); - }); - // move back to the origin - move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); - } - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), - v_tile); - - } while(++i_total_loops < num_total_loop); - - if constexpr(kStoreLSE) - { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } - } - }); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return o_acc; - } - - // Prefill, double lds - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* __restrict__ smem_ptrk0, - void* __restrict__ smem_ptrk1, - void* __restrict__ smem_ptrv0, - void* __restrict__ smem_ptrv1) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && - kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], - "wrong!"); - ignore = bias_dram_block_window_tmp; - ignore = position_encoding; - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - auto o_acc = OaccBlockTileType{}; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(o_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - // init M, L - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); - - const auto q_origin = q_dram_block_window_tmp.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) - { - const index_t logical_num_total_loop = - integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); - if(logical_num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse_acc = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse_acc, -numeric::infinity()); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. - return o_acc; - } - } - - // Q tile in LDS - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); - - auto q_lds_write_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_read_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_store_window = - make_tile_window(q_lds_write_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto q_lds_read_window = - make_tile_window(q_lds_read_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - - async_load_tile(q_lds_store_window, q_dram_window); - block_sync_lds_direct_load<0>(); - auto q_tile = load_tile(q_lds_read_window); - - // K tile in LDS - const index_t physical_seqlen_k_start = logical_seqlen_k_start; - const index_t physical_seqlen_k_end = logical_seqlen_k_end; - // make sure the first tile is completely located in page-block (page-block size should be - // divisible by kN0) - // relationship between each *_start variables: aligned_physical_seqlen_k_start <= - // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start - const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); - - auto k_lds_write_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_read_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_write_window = - make_tile_window(k_lds_write_view, - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto k_lds_read_window = - make_tile_window(k_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution()); - - // S tile in LDS - auto s_lds = make_tensor_view( - reinterpret_cast(reinterpret_cast(smem_ptrk0) + - Policy::template GetSmemSizeK()), - Policy::template MakeSLdsBlockDescriptor()); - auto s_write_lds_window = make_tile_window( - s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); - auto s_read_lds_window = - make_tile_window(s_lds, - Policy::template MakeSLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeSRegTileDistribution()); - - // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); - - auto v_lds_write_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template MakeVLdsBlockDescriptor()); - - auto v_lds_read_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template MakeVLdsBlockDescriptor()); - - auto v_lds_write_window = - make_tile_window(v_lds_write_view, - Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto v_lds_read_window = - make_tile_window(v_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution()); - - // block_sync_lds_direct_load<0>(); - // auto q_tile = load_tile(q_lds_read_window); - - const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); - block_sync_lds<0>(); - async_load_tile(k_lds_write_window, k_dram_window); - async_load_tile(v_lds_write_window, v_dram_window); - - move_tile_window(k_dram_window, {kN0, 0}); - k_lds_write_window.set_bottom_tensor_view_data_ptr( - static_cast(smem_ptrk1)); - async_load_tile(k_lds_write_window, k_dram_window); - - constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); - constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); - - constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access(); - constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access(); - - block_sync_lds_direct_load(); - auto k_tile = load_tile(k_lds_read_window); - - __builtin_amdgcn_sched_barrier(0); - - auto mainloop = [&](index_t cur_loop) { - const bool is_even_loop = (cur_loop % 2 == 0); - - auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) - : static_cast(smem_ptrk1); - auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) - : static_cast(smem_ptrk0); - auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) - : static_cast(smem_ptrv0); - auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) - : static_cast(smem_ptrv1); - - // move V tile windows - block_sync_lds(); - move_tile_window(v_dram_window, {kN0, 0}); - v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr); - async_load_tile(v_lds_write_window, v_dram_window); - - // STAGE 1, QK gemm - clear_tile(s_acc); // initialize C - - if constexpr(1 < k0_loops) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - // loop over along the [K]ey head dimension - move_tile_window(k_lds_read_window, {0, kK0}); - auto k_tile_switch = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_tile); - - k_tile = k_tile_switch; - }); - // move back to the origin - move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); - } - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_tile); - - block_sync_lds_direct_load(); - v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr); - auto v_tile = load_tile_transpose(v_lds_read_window); - - if constexpr(kHasUnevenSplits) - { - if(i_total_loops == (num_total_loop - 1)) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - set_tile_if(s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(I0) + tile_idx.at(I1); - - { - return physical_seqlen_k_end_ <= col; - } - }); - } - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - - bool need_perpixel_check = - mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(I0) + tile_idx.at(I0); - const auto col = k_origin.at(I0) + tile_idx.at(I1); - return mask.IsOutOfBound(row, col); - }); - } - } - - // Gemm1 - auto s_new = [&]() { - if constexpr(kNWarp > 1) - { - auto s = cast_tile(s_acc); // S{j} - - store_tile(s_write_lds_window, s); - block_sync_lds(); - return load_tile(s_read_lds_window); - } - else - { - return cast_tile(s_acc); // S{j} - } - }(); - - auto m_local = block_tile_reduce( - s_new, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync( - m_local, f_max, bool_constant{} /*, bool_constant{}*/); - - static_for<0, 12, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ - }); - - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ - }); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s_new.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - auto row_max = scale_s * get_validated_m(m[i_idx]); - sweep_tile_span(p_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); - } - } - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync( - rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); - - auto p_tile = make_static_distributed_tensor( - Policy::template MakePRegTileDistribution()); - p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); - - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - } - }(); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds(); - move_tile_window(k_dram_window, {kN0, 0}); - k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr); - async_load_tile(k_lds_write_window, k_dram_window); - - if constexpr(1 < k1_loops) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - // loop over along the [V]alue Sequence length - move_tile_window(v_lds_read_window, {kK1, 0}); - auto v_tile_switch = load_tile_transpose(v_lds_read_window); - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}), - v_tile); - - v_tile = v_tile_switch; - }); - // move back to the origin - move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); - } - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), - v_tile); - - block_sync_lds_direct_load(); - k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr); - k_tile = load_tile(k_lds_read_window); - - static_for<0, 12, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ - }); - - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ - }); - }; - - do - { - mainloop(i_total_loops); - i_total_loops++; - } while(i_total_loops < num_total_loop); - - if constexpr(kStoreLSE) - { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } - } - }); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return o_acc; - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp deleted file mode 100644 index ed22758566..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp +++ /dev/null @@ -1,823 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" - -// can remove all bank conflicts, but drop the performance for some cases -// Probably it is limited by compiler optimization. -#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 -namespace ck_tile { -// This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy -{ - using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); - - // this should align with MakeQDramTileDistribution() - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() - { - using OaccDataType = remove_cvref_t; - - return static_cast(16 / sizeof(OaccDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() - { - if constexpr(!BypassLDS) - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); - - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - else - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() - { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = - LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; - - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - - constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read M first, then K - // This is the same data consume order as BlockGEMM - constexpr auto q_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() - { - // TODO: this is for 3d layout - using QDataType = remove_cvref_t; - return static_cast(16 / sizeof(QDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t kKPack = GetSmemKPackQ(); - - constexpr auto q_lds_block_desc = [&]() { - if constexpr(Xor) - { -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType); - constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - q_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple( - number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return q_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = - LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; - - constexpr index_t kKPack = GetSmemKPackK(); - - constexpr auto k_lds_block_desc = [&]() { - if constexpr(Xor) - { -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType); - constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( - k_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor( - k_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - k_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple( - number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( - k_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - k_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return k_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t kKPack = GetSmemKPackV(); - - constexpr auto v_lds_block_desc = [&]() { - if constexpr(Xor) - { - constexpr auto XorGroupSize = - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType); - constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( - v_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor( - v_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - v_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( - v_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple( - number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - v_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return v_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() - { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - GemmLoopOrder::MNK>; - - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true>; - - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockGemmARegBRegCRegV1{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm() - { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - GemmLoopOrder::KMN>; - - using WarpGemm = WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || - (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single>; - - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockGemmARegBRegCRegV1{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read N first, then K - // This is the same data consume order as BlockGEMM - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - - template - CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t NPerThread = kMaxVecLoad; - constexpr index_t NThreads = kNPerBlock / NPerThread; - constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read M first, then K - // This is the same data consume order as BlockGEMM - constexpr auto p_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto p_block_dstr = make_static_tile_distribution(p_block_dstr_encode); - - return p_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read N first, then K - // This is the same data consume order as BlockGEMM - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = - make_static_tile_distribution(typename InputTileDistributionTraits< - decltype(v_block_dstr_encode), - typename Problem::VDataType>::TransposedDstrEncode{}); - - return v_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() - { - using SDataType = remove_cvref_t; - return static_cast(16 / sizeof(SDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kNPack = GetSmemNPackS(); - - constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto s_lds_block_desc = transform_tensor_descriptor( - s_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return s_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - // static_assert(MWarp == 1, "Check failed!"); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; - - // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm - constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = kKPerBlock / (K2 * K3); - constexpr index_t K0 = kTileK / kKPerBlock; - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - constexpr auto s2_block_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<2, 2>>, - sequence<1, 2, 2, 2>, - sequence<0, 0, 1, 3>>{}; - - constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); - - return s2_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - return MakeQLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::QDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::KDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() - { - return MakeVLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::VDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS() - { - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - return NWarp > 1 ? MakeSLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::SaccDataType) - : 0; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - // Alignment on gfx950 is 1280 Bytes - // Alignment before gfx950 is 512 Bytes. - return max(GetSmemSizeQ(), - GetSmemSizeK() + GetSmemSizeS() + GetSmemSizeV()); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index e2cea97f9a..3489d6f9a1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -383,31 +383,23 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); - - return kMaxVecLoad; + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; - using VDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); - + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; if constexpr(std::is_same_v) { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) @@ -418,7 +410,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder; + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -89,36 +86,17 @@ struct BlockGemmARegBRegCRegV1 } else { - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } + return a_block_dstr_encode; } } @@ -140,33 +118,17 @@ struct BlockGemmARegBRegCRegV1 } else { - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; - } + return b_block_dstr_encode; } } @@ -251,82 +213,40 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std:: + conditional_t, sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); - } + }); } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index d0be065fc9..fd5211a59a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -14,8 +13,7 @@ template + index_t NumWaveGroups_ = 1> struct BlockGemmProblem { using ADataType = remove_cvref_t; @@ -23,9 +21,8 @@ struct BlockGemmProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t NumWaveGroups = NumWaveGroups_; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index b3c86b9456..b18bf603a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -39,12 +39,6 @@ enum struct TailNumber Full, }; -enum struct GemmLoopOrder -{ - KMN, - MNK, -}; - } // namespace ck_tile inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c628614b54..52bd07c9e2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,11 +14,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -46,10 +45,9 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; // In the base situation, the Preshuffle setting should be false. static constexpr bool Preshuffle = false; @@ -169,11 +167,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> using GemmPipelineProblem = GemmPipelineProblemBase; + VectorSizeB_>; template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem static constexpr auto Scheduler = Scheduler_; static constexpr bool Preshuffle = Traits::Preshuffle; - static constexpr index_t VectorSizeA = VectorSizeA_; - static constexpr index_t VectorSizeB = VectorSizeB_; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index d1deaf9e0e..fb191d565d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -104,10 +104,6 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = 1>>; #endif -using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; - #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; #endif -using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; - #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -76,8 +74,6 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 7a10d1fa56..434be9f84a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -14,14 +14,10 @@ namespace ck_tile { * Y dim must have at least one dim not been reduced */ // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) -template +template CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, const ReduceFunc& reduce_func, - bool_constant = {}, - bool_constant = {}) + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -60,24 +56,14 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, // reduction sweep forward static_for<0, nstage, 1>{}([&](auto istage) { - if constexpr(CrossWarp) - { - constexpr index_t lid_delta = - lid_over_rid_derivative * (1 << (nstage - istage - 1)); + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); - // pull data from remote lane - const auto v_remote = warp_shuffle_down(v_local, lid_delta); + // pull data from remote lane + const auto v_remote = warp_shuffle_down(v_local, lid_delta); - // reduce - v_local = reduce_func(v_local, v_remote); - } - else - { - // pull data from remote lane - const auto v_swapped_regs = warp_shuffle_down_pair(v_local); - // reduce - v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1)); - } + // reduce + v_local = reduce_func(v_local, v_remote); }); } }); From 20288caa2f20082187a5e0d39d28907e1baf766e Mon Sep 17 00:00:00 2001 From: slippedJim Date: Wed, 13 Aug 2025 00:23:40 +0800 Subject: [PATCH 4/9] remove bad pipeline codegen (#2673) --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 269af4e6a7..471486419a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -507,7 +507,7 @@ class KernelComponentFactory: (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 0e4ac44d45..b2d962cd74 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -657,7 +657,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d '64' : FmhaFwdSplitKVCombineTileSize(32, -1), '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': From bbf41b27f2e533c431edda39850af1a8630f483f Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 12 Aug 2025 10:23:08 -0700 Subject: [PATCH 5/9] fix builds with mainline/staging compilers (#2674) --- Jenkinsfile | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 590ee92e90..619f15d624 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -460,7 +460,9 @@ def buildHipClangJob(Map conf=[:]){ } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtima libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') @@ -518,7 +520,9 @@ def Build_CK(Map conf=[:]){ } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtima libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } if(params.BUILD_LEGACY_OS){ dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' " From 0856b3f4a29bd454fb8a9cef3d8776fb84e38119 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Wed, 13 Aug 2025 03:33:56 +0800 Subject: [PATCH 6/9] [CK_TILE]fix ck_tile's moe_sorting example in gfx11 (#2667) * fix ck_tile's moe_sorting example in gfx11 * fix clang format --------- Co-authored-by: illsilin_amdeng --- .../flatmm_32x512x128_1x4x1_16x16x32.hpp | 100 ++++++++++-------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 23c4ad583e..21ca470222 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8 static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4 - static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + private: + template + struct LdsStoreDescSelector; + + template + struct LdsStoreDescSelector= WarpSize)>> { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, // !! note here is different - sequence<0, 0>>{}; - - using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - return c_block_dstr; - } - - static CK_TILE_DEVICE constexpr auto MakeCBlockTile() - { - using CDataType = float; - constexpr auto c_block_dstr = MakeCBlockDist(); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } - - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() - { - // A async->LDS - // constexpr index_t Block_M = Problem::BlockShape::Block_M0; - // constexpr index_t Block_K = Problem::BlockShape::Block_K0; - // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t WarpSize = ck_tile::get_warp_size(); - // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; - - constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS - constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword - constexpr index_t KPad = KPack_; // pad between warps - - static_assert(Block_K % KVector == 0); - constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= WarpSize) + template + static CK_TILE_HOST_DEVICE constexpr auto MakeDesc() { // need multiple waves to load K static_assert(LanesPerK % WarpSize == 0); @@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 return lds_block_desc_issues_warps_lanes; } } - else + }; + + template + struct LdsStoreDescSelector> + { + template + static CK_TILE_HOST_DEVICE constexpr auto MakeDesc() { // lanes within a wave load different M but same K static_assert(WarpSize % LanesPerK == 0); @@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 return lds_block_desc_issues_warps_lanes; } + }; + + public: + static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, // !! note here is different + sequence<0, 0>>{}; + + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + static CK_TILE_DEVICE constexpr auto MakeCBlockTile() + { + using CDataType = float; + constexpr auto c_block_dstr = MakeCBlockDist(); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack_; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + + return LdsStoreDescSelector:: + template MakeDesc(); } // template From 30dafe82810bd49a186149007f33ebbf120084de Mon Sep 17 00:00:00 2001 From: Geo Min Date: Tue, 12 Aug 2025 14:13:01 -0700 Subject: [PATCH 7/9] [TheRock CI] Adding TheRock CI gate check (#2648) * Adding initial TheRock CI * Adding composable kernel link * Adding correct repo for rocm-libraries * Adding entire rocm-libraries checkout * Adding correct flag * Adding correct flag for fetch sources * Fixing git health * Removing patch * Removing patching * Removing manual check * PR comments * testing without dist * Removing test branch * PR comments * PR comments * PR comment * Adding test_runs_on --- .github/workflows/therock-ci-linux.yml | 128 ++++++++++++++++++++ .github/workflows/therock-ci.yml | 50 ++++++++ .github/workflows/therock-test-packages.yml | 76 ++++++++++++ 3 files changed, 254 insertions(+) create mode 100644 .github/workflows/therock-ci-linux.yml create mode 100644 .github/workflows/therock-ci.yml create mode 100644 .github/workflows/therock-test-packages.yml diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml new file mode 100644 index 0000000000..645a91c030 --- /dev/null +++ b/.github/workflows/therock-ci-linux.yml @@ -0,0 +1,128 @@ +name: TheRock CI Linux + +on: + workflow_call: + inputs: + cmake_options: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + +permissions: + contents: read + +jobs: + therock-build-linux: + name: Build Linux Packages + runs-on: azure-linux-scale-rocm + permissions: + id-token: write + container: + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292 + env: + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + TEATIME_FORCE_INTERACTIVE: 0 + steps: + - name: Checkout composable_kernel repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Checkout TheRock repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + path: "TheRock" + + - name: Runner Health Settings + run: | + df -h + cmake --version + echo "Installed Python versions:" + ls -d /opt/python + echo "python: $(which python), python3: $(which python3)" + echo "Git version: $(git --version)" + git config --global --add safe.directory $PWD + git config fetch.parallel 10 + + - name: Fetch sources + run: | + ./TheRock/build_tools/fetch_sources.py --jobs 12 + + - name: Install python deps + run: | + pip install -r TheRock/requirements.txt + pip freeze + + - name: Configure Projects + env: + amdgpu_families: ${{ env.AMDGPU_FAMILIES }} + package_version: ADHOCBUILD + extra_cmake_options: ${{ inputs.cmake_options }} + BUILD_DIR: build + run: | + python3 TheRock/build_tools/github_actions/build_configure.py + + - name: Build TheRock + run: cmake --build TheRock/build + + - name: Build therock-archives + run: cmake --build TheRock/build --target therock-archives + + - name: Report + if: ${{ !cancelled() }} + run: | + echo "Full SDK du:" + echo "------------" + du -h -d 1 TheRock/build/dist/rocm + echo "Artifact Archives:" + echo "------------------" + ls -lh TheRock/build/artifacts/*.tar.xz + echo "Artifacts:" + echo "----------" + du -h -d 1 TheRock/build/artifacts + + - name: Configure AWS Credentials + if: always() + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + aws-region: us-east-2 + role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external + + - name: Create Logs index Files and upload logs + if: always() + run: | + python3 TheRock/build_tools/github_actions/create_log_index.py \ + --build-dir=TheRock/build \ + --amdgpu-family=${{ env.AMDGPU_FAMILIES }} + + python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ + --build-dir=TheRock/build \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} + + - name: Upload artifacts + run: | + python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + - name: Add Links to Job Summary + if: always() + run: | + python TheRock/build_tools/github_actions/upload_build_summary.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + therock-test-linux: + name: "Test" + needs: [therock-build-linux] + uses: ./.github/workflows/therock-test-packages.yml + with: + project_to_test: "miopen" + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: "linux" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml new file mode 100644 index 0000000000..18411baa09 --- /dev/null +++ b/.github/workflows/therock-ci.yml @@ -0,0 +1,50 @@ +name: TheRock CI for composable_kernel + +on: + push: + branches: + - develop + workflow_dispatch: + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + therock-ci-linux: + name: TheRock CI Linux + permissions: + contents: read + id-token: write + uses: ./.github/workflows/therock-ci-linux.yml + secrets: inherit + with: + cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + amdgpu_families: "gfx94X-dcgpu" + test_runs_on: "linux-mi325-1gpu-ossci-rocm" + + therock_ci_summary: + name: TheRock CI Summary + if: always() + needs: + - therock-ci-linux + runs-on: ubuntu-24.04 + steps: + - name: Output failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml new file mode 100644 index 0000000000..439135743c --- /dev/null +++ b/.github/workflows/therock-test-packages.yml @@ -0,0 +1,76 @@ +name: TheRock Test Packages + +on: + workflow_call: + inputs: + project_to_test: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + +permissions: + contents: read + +jobs: + configure_test_matrix: + name: "Configure test matrix" + runs-on: ubuntu-24.04 + if: ${{ inputs.test_runs_on != '' }} + outputs: + components: ${{ steps.configure.outputs.components }} + steps: + - name: "Checking out repository" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: "Configuring CI options" + env: + PLATFORM: ${{ inputs.platform }} + project_to_test: ${{ inputs.project_to_test }} + id: configure + run: python ./build_tools/github_actions/fetch_test_configurations.py + + test_components: + name: 'Test ${{ matrix.components.job_name }}' + runs-on: ${{ inputs.test_runs_on }} + needs: configure_test_matrix + # skip tests if no test matrix to run + if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} + strategy: + fail-fast: false + matrix: + components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ github.run_id }}" + OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build + THEROCK_BIN_DIR: "./build/bin" + steps: + - name: Checkout Repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} + PLATFORM: ${{ inputs.platform }} + + - name: Test + timeout-minutes: ${{ matrix.components.timeout_minutes }} + run: | + if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi + ${{ matrix.components.test_script }} From 3f57ec3d2dc856a30ca1c652eda19e5dd4ee6041 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Tue, 12 Aug 2025 18:05:05 -0500 Subject: [PATCH 8/9] GEMM Multi D for CK Tile Engine (#2660) * Readme for GEMM Multi D * GEMM Multi D partial Progress * GEMM Multi D partial Progress! * CK Tile Engine GEMM Multi D : All Python files generated * Partial Progress * Partial Progress * Partial Progress * Partial Progress : Incorrect Result * Partial Progress : Debugging * Partial Progress : Correct Results * Partial Progress - Incorrect Results * Partial Progress - Commenting Passthrough bypass logic * Changing Passthrough to MultiplyMultiply * Correct Results! * Fix and debug the pass through feature * Sample commit * Correct Results : MultiplyMultiply * Code Cleanup * Removing Failed Instances * Working code before Unary element support * Custom Elementwise Function support and working implementation for Mul and Add * Updating README * Working for Passthrough * Review Comments : Minor Fixes * Review Comments : Minor Fixes * Readme Updated * Partial Changes after Rebase * Working Code : Changes after Rebase * Updating Jenkins file * Removing default value changed while testing * Configuration changes in config files * Tile Handler changes in GEMM Multi D Tile Engine * Tile Handler changes in GEMM Multi D Example * Change log for Gemm Multi D in CK Tile Engine * Configuration changes in config files --------- Co-authored-by: ThomasNing --- CHANGELOG.md | 1 + Jenkinsfile | 24 +- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 90 +-- .../unary_element_wise_operation.hpp | 242 ++---- include/ck_tile/ops/reduce.hpp | 6 +- tile_engine/ops/CMakeLists.txt | 1 + tile_engine/ops/gemm_multi_d/CMakeLists.txt | 152 ++++ tile_engine/ops/gemm_multi_d/README.md | 110 +++ .../gemm_multi_d/benchmark_gemm_multi_d.cpp | 73 ++ .../gemm_multi_d/benchmark_gemm_multi_d.hpp | 218 +++++ .../configs/custom_ci_config.json | 80 ++ .../gemm_multi_d/configs/default_config.json | 84 ++ .../configs/user_provided_config.json | 81 ++ .../gemm_multi_d_codegen_utils.py | 229 ++++++ .../ops/gemm_multi_d/gemm_multi_d_config.py | 250 ++++++ .../gemm_multi_d/gemm_multi_d_host_api.hpp | 164 ++++ .../gemm_multi_d_instance_builder.py | 755 ++++++++++++++++++ .../gemm_multi_d/gemm_multi_d_profiler.hpp | 278 +++++++ 18 files changed, 2547 insertions(+), 291 deletions(-) create mode 100644 tile_engine/ops/gemm_multi_d/CMakeLists.txt create mode 100644 tile_engine/ops/gemm_multi_d/README.md create mode 100644 tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp create mode 100644 tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp create mode 100644 tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json create mode 100644 tile_engine/ops/gemm_multi_d/configs/default_config.json create mode 100644 tile_engine/ops/gemm_multi_d/configs/user_provided_config.json create mode 100644 tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py create mode 100644 tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py create mode 100644 tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp create mode 100755 tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py create mode 100644 tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c942a776d..7c09271edc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. * Added support for elementwise kernel. +* Added benchmarking support for tile engine GEMM Multi D. ### Optimized diff --git a/Jenkinsfile b/Jenkinsfile index 619f15d624..7955b8733a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1176,6 +1176,8 @@ pipeline { -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D DGEMM_MULTI_D_DATATYPE="fp16" \ + -D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_fp8_rcr && \ ./bin/benchmark_gemm_fp8_rcr && \ @@ -1192,7 +1194,15 @@ pipeline { ninja -j64 benchmark_gemm_fp8_rrr && \ ./bin/benchmark_gemm_fp8_rrr && \ ninja -j64 benchmark_gemm_fp16_rrr && \ - ./bin/benchmark_gemm_fp16_rrr """ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1214,6 +1224,8 @@ pipeline { -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D DGEMM_MULTI_D_DATATYPE="fp16" \ + -D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_fp8_rcr && \ ./bin/benchmark_gemm_fp8_rcr && \ @@ -1230,7 +1242,15 @@ pipeline { ninja -j64 benchmark_gemm_fp8_rrr && \ ./bin/benchmark_gemm_fp8_rrr && \ ninja -j64 benchmark_gemm_fp16_rrr && \ - ./bin/benchmark_gemm_fp16_rrr """ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 8971871c14..d7bf2b5c42 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -197,95 +197,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& } }; - if(has_hot_loop) - { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" << tail_num - << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - if(tail_num == ck_tile::TailNumber::One) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - auto check_tail = [&](auto... TNs) { - (try_run(tail_num), ...); - }; - - check_tail(ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}); - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -#endif - } - else - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; } diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 0e385901ed..2f8cef7afd 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -262,219 +262,67 @@ struct PassThroughPack2 struct PassThrough { - template - CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + template + using raw_t = std::remove_cv_t>; - template <> - CK_TILE_HOST_DEVICE void operator()(double& y, const double& x) const + template + CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const { - y = x; + /* Only do the assignment when + - y is an *l-value* and + - y is *not* const */ + if constexpr(std::is_lvalue_reference_v && !std::is_const_v>) + { + y = ck_tile::type_convert>(x); + } + /* otherwise (r-value or const) → do nothing */ } - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, const double& x) const + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { - y = type_convert(x); - } + // Suppress unused parameter warning for ds + ((void)ds, ...); - template <> - CK_TILE_HOST_DEVICE void operator()(double& y, const float& x) const - { - y = type_convert(x); + // Just assign e with c + if constexpr(std::is_same_v) + { + e = c; + } + else + { + e = ck_tile::type_convert(c); + } } +}; - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, const float& x) const +struct MultiDMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { - y = x; - } + // Start with the base value c + float result = ck_tile::type_convert(c); - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const - { - y = x; - } + // Multiply by each D parameter using fold expression + ((result *= ck_tile::type_convert(ds)), ...); - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::fp16_t& y, - const float& x) const - { - y = type_convert(x); + e = ck_tile::type_convert(result); } +}; - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const +struct MultiDAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void { - y = x; - } + // Start with the base value c + float result = ck_tile::type_convert(c); - template <> - CK_TILE_HOST_DEVICE void operator()(int32_t& y, const int32_t& x) const - { - y = x; - } + // Add by each D parameter using fold expression + ((result += ck_tile::type_convert(ds)), ...); - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, - const float& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, - const ck_tile::bf16_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, - const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(int8_t& y, const int8_t& x) const - { - y = x; - } - - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::fp16_t& y, - const int8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, - const int8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(uint8_t& y, const uint8_t& x) const - { - y = x; - } - - template <> - CK_TILE_HOST_DEVICE void operator()(int8_t& y, const int32_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(int32_t& y, const int8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(int8_t& y, const float& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, const int8_t& x) const - { - y = type_convert(x); - } - -#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - template <> - CK_TILE_HOST_DEVICE void operator()(int4_t& y, const int4_t& x) const - { - y = x; - } - template <> - CK_TILE_HOST_DEVICE void operator()(int4_t& y, const int& x) const - { - y = type_convert(x); - } -#endif - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const - { - y = x; - } - - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, - const ck_tile::fp8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::fp8_t& y, - const float& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const - { - y = x; - } - - template <> - CK_TILE_HOST_DEVICE void operator()(float& y, - const ck_tile::bf8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void operator()(ck_tile::bf8_t& y, - const float& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const - { - y = type_convert(x); - } - - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const - { - y = ck_tile::type_convert(x); + e = ck_tile::type_convert(result); } }; diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index 042e0b98c2..a6721c9305 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -6,10 +6,10 @@ #include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" -#include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt index 0cf2c16da2..7d7002af1b 100644 --- a/tile_engine/ops/CMakeLists.txt +++ b/tile_engine/ops/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(gemm) +add_subdirectory(gemm_multi_d) \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/CMakeLists.txt b/tile_engine/ops/gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..3708dd3fee --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/CMakeLists.txt @@ -0,0 +1,152 @@ + +set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)") +set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)") +set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function") + +function(build_gemm_multi_d_for_datatype_layout datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Comment this if-else block when using user_provided_config + if(layout STREQUAL "rcrr") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + else() + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") + endif() + + # uncomment this if you want to use user_provided_config.json + # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") + + # Generate kernel list + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION} + --config_json ${json_blob} + --list_blobs + RESULT_VARIABLE ret + ) + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}") + endif() + + file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs) + file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range) + + # Generate the blobs + add_custom_command( + OUTPUT ${codegen_blobs} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py + --working_path "${working_path}" + --datatype ${datatype} + --layout ${layout} + --elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION} + --config_json "${json_blob}" + --gen_blobs + COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}" + ) + add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs}) + + set(intermediate_libs) + list(LENGTH codegen_blobs codegen_blobs_len) + + foreach(blob IN LISTS codegen_blobs_range) + string(STRIP "${blob}" stripped_blob) + separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}") + # Each line is: + list(GET spilit_blob 0 name) + list(GET spilit_blob 1 first) + list(GET spilit_blob 2 last) + math(EXPR total_files "${last} - ${first}") + if(total_files EQUAL 0) + continue() # nothing for this trait + endif() + + # Object libraries (chunked) per trait + set(sub_intermediate_libs) + set(chunk_size 3) + math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}") + math(EXPR num_chunks_minus_1 "${num_chunks} - 1") + + foreach(i RANGE 0 ${num_chunks_minus_1}) + math(EXPR start "${first} + ${i} * ${chunk_size} ") + math(EXPR end "${start} + ${chunk_size} - 1") + + set(chunk_files) + foreach(j RANGE ${start} ${end}) + if(j LESS ${last} AND j LESS ${codegen_blobs_len}) + list(GET codegen_blobs ${j} f) + list(APPEND chunk_files "${f}") + endif() + endforeach() + + #list(LENGTH chunk_files chunk_files_len) + #if(chunk_files_len AND chunk_files_len GREATER 1) + if(chunk_files) + set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}") + add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) + list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) + endif() + + endforeach() + + # ------------------ Bundle the object libs into one static lib --------- + #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) + #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) + if(sub_intermediate_libs) + set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}") + # Collect the $ expressions + + set(obj_exprs) + foreach(objlib IN LISTS sub_intermediate_libs) + list(APPEND obj_exprs $) + endforeach() + + add_library(${intermediate_lib_name} STATIC ${obj_exprs}) + add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout}) + #foreach(objlib IN LISTS sub_intermediate_libs) + # target_sources(${intermediate_lib_name} PRIVATE $) + #endforeach() + list(APPEND intermediate_libs ${intermediate_lib_name}) + endif() + + endforeach() + + # Interface library for instances + add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE) + add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout}) + target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs}) + target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE + ${CMAKE_CURRENT_LIST_DIR} + "${working_path}" + ) + set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX) + + # Host API interface library + add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE) + target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout}) + target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE + ${CMAKE_CURRENT_LIST_DIR} + "${working_path}" + ) + + + + # Executable per datatype + set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}") + add_executable(${exec_name} benchmark_gemm_multi_d.cpp) + target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout}) + target_compile_options(${exec_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) +endfunction() + +# Process each datatype in isolation +foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE) + foreach(l IN LISTS GEMM_MULTI_D_LAYOUT) + build_gemm_multi_d_for_datatype_layout(${dt} ${l}) + endforeach() +endforeach() diff --git a/tile_engine/ops/gemm_multi_d/README.md b/tile_engine/ops/gemm_multi_d/README.md new file mode 100644 index 0000000000..369553b121 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/README.md @@ -0,0 +1,110 @@ + +CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections + +# Kernel Configurations + +# User Specific +Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time. +For reference please see `./configs/user_provided_config.json`. + +# Default +The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` + +If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. + +## Build Instructions +``` bash +# in the root of composable kernel create build directory +mkdir build && cd build +# build composable kernel +# replace [Arch] with the appropriate architecture or leave blank and +# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16]) +# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) +# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default. +sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul" +# generate different executable for each passed datatype +make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j +make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j +``` +`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory. + +`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified. + +``` bash +rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild +``` + +## For eaxmple build for gfx942 for datatype with rcr layout +``` bash +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr" +make benchmark_gemm_multi_d_fp16_rcrr -j + +## benchmark_gemm inputs +``` + -m The value for m dimension. Default is 3840. + -n The value for n dimension. Default is 4096. + -k The value for k dimension. Default is 2048. + -stride_a The stride value for tensor A. Default is 0. + -stride_b The stride value for tensor B. Default is 0. + -stride_ds The stride value for tensor Ds. Default is 0. + -stride_e The stride value for tensor E. Default is 0. + -split_k The split value for k dimension. Default is 1. + -verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported. + -log Wether output kernel instance information or not. Possible values are true or false. Default is false. + -warmup The number of iterations before benchmark the kernel. Default is 50. + -repeat The number of iterations to benchmark the kernel. Default is 100. + -timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true. + -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. + -flush_cache To flush cache, possible values are true or false. Default is false. + -rotating_count Number of iterations to rotate the cache. Default is 5. + -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. + -csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel. + -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. + -scheduler The type of scheduler. Possible values are intrawave. Default is intrawave. + -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. + -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. + -pad_n Whether pad or not in n direction. Possible values are true or false. Default is false. + -pad_k Whether pad or not in k direction. Possible values are true or false. Default is false. + +Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json +``` +Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. + +## Example + +The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. + +```json +{ + /// other parameters /// + + "tile_m": { + "values": [256] + }, + "tile_n": { + "values": [256] + }, + "tile_k": { + "values": [64, 32] + }, + + /// other parameters /// + + "pipeline": { + "values": ["compv3", "compv4", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["cshuffle"] + } +} +``` + +At runtime, a specific subset of the generated kernels can be selected using command-line arguments. +``` bash +./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle +``` +The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. diff --git a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp b/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp new file mode 100644 index 0000000000..764a295809 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "benchmark_gemm_multi_d.hpp" +#include "gemm_multi_d_profiler.hpp" + +void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser) +{ + GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_ds"), + arg_parser.get_int("stride_ds"), + arg_parser.get_int("stride_e"), + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + ALayout::name, + BLayout::name, + D0Layout::name, + D1Layout::name, + ELayout::name}; + + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count")}; + + auto& profiler = GemmMultiDProfiler::instance(setting); + + try + { + auto kernel_func = get_kernel_func_by_trait(arg_parser); + profiler.benchmark(gemm_multi_d_problem, kernel_func); + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + benchmark_gemm_multi_d(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp b/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp new file mode 100644 index 0000000000..f52d69e374 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "gemm_multi_d_host_api.hpp" + +struct GemmMultiDProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_; + + std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_; + std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_; + + friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_d0\":" << problem.stride_d0_ << ",\n" + << " \"stride_d1\":" << problem.stride_d1_ << ",\n" + << " \"stride_e\":" << problem.stride_e_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n" + << " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n" + << " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n" + << " \"layout_e\":\"" << problem.layout_e_ << "\"\n" + << "}"; + return os; + } +}; + +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; +}; + +// @brief Function to get the kernel output with reference implementation on CPU +void gemm_multi_d_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& d0_m_n, + ck_tile::HostTensor& d1_m_n, + ck_tile::HostTensor& e_m_n_host_result) +{ + if(verify > 0) + { + // Currently supporting on CPU verification for Gemm Multi D + // e_m_n_host_result.SetZero(); + ck_tile::reference_gemm_multiple_d( + a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result); + } +} + +enum class Metric +{ + LATENCY = 0, + TFLOPS = 1, + BANDWIDTH = 2 +}; + +inline constexpr auto get_metric_name(Metric m) +{ + switch(m) + { + case Metric::LATENCY: return "latency"; + case Metric::TFLOPS: return "tflops"; + case Metric::BANDWIDTH: return "bandwidth"; + default: throw std::invalid_argument("Unsupported metric type"); + } +} + +struct PerformanceResult +{ + double latency_; + double tflops_; + double bandwidth_; + + static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) + { + switch(m) + { + case Metric::LATENCY: return a.latency_ < b.latency_; + case Metric::TFLOPS: return a.tflops_ > b.tflops_; + case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; + default: throw std::invalid_argument("Unsupported metric type"); + } + } + + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) + { + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" + << "}"; + return os; + } +}; + +struct KernelInstance +{ + std::string name_; + GemmMultiDProblem problem_; + PerformanceResult perf_result_; + + static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) + { + return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); + } + + friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) + { + os << "{\n" + << " \"name\": \"" << "{\n" + << obj.name_ << "\n}" << "\",\n" + << " \"problem\": \"" << obj.problem_ << "\",\n" + << " \"perf_result\": " << obj.perf_result_ << "\n" + << "}"; + return os; + } +}; + +inline std::string get_rocm_version() +{ + std::ifstream version_file("/opt/rocm/.info/version"); + if(version_file.is_open()) + { + std::string version; + std::getline(version_file, version); + return version; + } + return "Unknown"; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations +bool compare(std::string instanceName, + ck_tile::index_t K, + ck_tile::HostTensor& e_m_n_dev_result, + ck_tile::HostTensor& e_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + bool pass = ck_tile::check_err(e_m_n_dev_result, + e_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} diff --git a/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json b/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json new file mode 100644 index 0000000000..cd638d9af0 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json @@ -0,0 +1,80 @@ +{ + "tile_config": { + "tile_m": { + "values": [ + 256 ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 32 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/configs/default_config.json b/tile_engine/ops/gemm_multi_d/configs/default_config.json new file mode 100644 index 0000000000..6d1afa4425 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/configs/default_config.json @@ -0,0 +1,84 @@ +{ + "tile_config": { + "tile_m": { + "values": [ + 256 + ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 32 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json b/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json new file mode 100644 index 0000000000..243d858fe5 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json @@ -0,0 +1,81 @@ +{ + "tile_config": { + "tile_m": { + "values": [ + 256 + ] + }, + "tile_n": { + "values": [ + 256 + ] + }, + "tile_k": { + "values": [ + 64 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 32 + ] + }, + "warp_tile_n": { + "values": [ + 32 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py new file mode 100644 index 0000000000..7d3629819d --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# -*- coding: utf-8 -*- + +""" +Mappings and utility functions for kernel code generation. +""" + +import subprocess +import re +from functools import lru_cache + +DATA_TYPE_MAP = { + "fp32": "float", + "fp16": "ck_tile::half_t", + "bf16": "ck_tile::bf16_t", + "int8": "ck_tile::int8_t", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int4": "ck_tile::pk_int4_t", + "int32": "ck_tile::int32_t", +} + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + + +# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW +# DEFAULT_EPILOGUE = """ +# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< +# ck_tile::DefaultGemm2DEpilogueProblem>; +# """ + +CSHUFFLE_EPILOGUE = """ + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; +""" + +PIPELINE_MAP = { + "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"], + "compv3": [ + "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "ck_tile::GemmPipelineAgBgCrCompV3", + ], + "compv4": [ + "ck_tile::BaseGemmPipelineAgBgCrCompV4", + "ck_tile::GemmPipelineAgBgCrCompV4", + ], +} + +SCHEDULER_MAP = { + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", +} + +# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} + +EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE} + + +def BOOL_MAP(b_): + return {True: "true", False: "false"}[bool(b_)] + + +# Can add some more supported combinations +warp_tile_supported_combinations = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Remove some unsupported combinations +trait_unsupported_combinations = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), +} + + +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + print("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + print("GPU query timeout (5s)") + except Exception as e: + print(f"GPU detection error: {str(e)}") + + return "" diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py new file mode 100644 index 0000000000..e5a879158f --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# -*- coding: utf-8 -*- + +""" +Handles loading, parsing, and validation of JSON and Argument configuration parameters. +""" + +from pathlib import Path +from dataclasses import dataclass +from typing import List, Optional, Union, Type +import json + + +@dataclass +class EnumConfigParam: + """Represents an enumeration-type configuration parameter""" + + values: List[Union[int, str, bool]] + + +@dataclass +class RangeConfigParam: + """Represents a numeric range-type configuration parameter""" + + min: int + max: int + step: int + exclude: Optional[List[int]] + + def generate_candidates(self) -> List[int]: + """Generates valid candidates after applying range constraints""" + + if self.min > self.max: + raise ValueError(f"Invalid range: min({self.min}) > max({self.max})") + if self.step <= 0: + raise ValueError(f"Step must be positive, got {self.step}") + + candidates = list(range(self.min, self.max + 1, self.step)) + + if hasattr(self, "exclude") and self.exclude: + if not isinstance(self.exclude, list): + raise TypeError("exclude must be list type") + exclude_set = set(self.exclude) + candidates = [x for x in candidates if x not in exclude_set] + + if not candidates: + raise ValueError( + f"No valid candidates for range [{self.min}-{self.max}] " + f"with step {self.step} and excludes {self.exclude}" + ) + + return candidates + + +@dataclass +class DataType: + """Configuration class for data type parameter.""" + + a_datatype: str + b_datatype: str + e_datatype: str + d0_datatype: str + d1_datatype: str + ds_datatype: List[str] + + +@dataclass +class Layout: + """Configuration class for Layout parameter.""" + + a_layout: str + b_layout: str + e_layout: str + d0_layout: str + d1_layout: str + ds_layout: List[str] + + +@dataclass +class ArgumentConfig: + """Configuration class for Argument parameter.""" + + datatypes: DataType + layouts: Layout + function_name: str + + @classmethod + def from_args( + cls: Type["ArgumentConfig"], + datatype: str, + layout: str, + elementwise_function: str, + ) -> "ArgumentConfig": + """configuration loader with validation controls""" + + datatypes = DataType( + a_datatype=datatype, + b_datatype=datatype, + e_datatype=datatype, + d0_datatype=datatype, + d1_datatype=datatype, + ds_datatype=[datatype, datatype], + ) + + layout_parts = layout.lower() + assert len(layout_parts) == 4, ( + f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ("r", "c"), ( + f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)" + ) + assert layout_parts[1] in ("r", "c"), ( + f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + assert layout_parts[3] == "r", ( + f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)" + ) + + layouts = Layout( + a_layout=layout[0], + b_layout=layout[1], + e_layout=layout[2], + d0_layout=layout[3], + d1_layout=layout[3], + ds_layout=[layout[3], layout[3]], + ) + # Elementwise function name validation + valid_functions = ["mul", "add", "passthrough"] + if elementwise_function not in valid_functions: + raise ValueError( + f"Invalid elementwise function: {elementwise_function}. " + f"Valid options are: {', '.join(valid_functions)}" + ) + + # Set the function name based on the elementwise function + if elementwise_function == "mul": + function_name = "MultiDMultiply" + elif elementwise_function == "add": + function_name = "MultiDAdd" + elif elementwise_function == "passthrough": + function_name = "PassThrough" # TODO Change this + + return cls(datatypes=datatypes, layouts=layouts, function_name=function_name) + + +@dataclass +class TileConfig: + """Configuration class for tile parameter.""" + + tile_m: Union[EnumConfigParam, RangeConfigParam] + tile_n: Union[EnumConfigParam, RangeConfigParam] + tile_k: Union[EnumConfigParam, RangeConfigParam] + + warp_m: Union[EnumConfigParam, RangeConfigParam] + warp_n: Union[EnumConfigParam, RangeConfigParam] + warp_k: Union[EnumConfigParam, RangeConfigParam] + + warp_tile_m: Union[EnumConfigParam, RangeConfigParam] + warp_tile_n: Union[EnumConfigParam, RangeConfigParam] + warp_tile_k: Union[EnumConfigParam, RangeConfigParam] + + +@dataclass +class TraitConfig: + """Configuration class for kernel traits.""" + + pipeline: EnumConfigParam + scheduler: EnumConfigParam + epilogue: EnumConfigParam + pad_m: EnumConfigParam + pad_n: EnumConfigParam + pad_k: EnumConfigParam + + +@dataclass +class JsonConfig: + """Configuration class for JSON parameter.""" + + tile_config: TileConfig + trait_config: TraitConfig + + @classmethod + def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig": + """JSON configuration loader with validation controls""" + config_path = Path(filepath) + + try: + if not config_path.exists(): + raise FileNotFoundError(f"Config file {filepath} not found") + + with config_path.open("r") as f: + config_dict = json.load(f) + + # Parse tile config + def create_param(param_dict): + if "values" in param_dict: + return EnumConfigParam(values=param_dict["values"]) + else: + return RangeConfigParam( + min=param_dict["min"], + max=param_dict["max"], + step=param_dict["step"], + exclude=param_dict.get("exclude", []), + ) + + tile_config = TileConfig( + tile_m=create_param(config_dict["tile_config"]["tile_m"]), + tile_n=create_param(config_dict["tile_config"]["tile_n"]), + tile_k=create_param(config_dict["tile_config"]["tile_k"]), + warp_m=create_param(config_dict["tile_config"]["warp_m"]), + warp_n=create_param(config_dict["tile_config"]["warp_n"]), + warp_k=create_param(config_dict["tile_config"]["warp_k"]), + warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]), + warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]), + warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]), + ) + + # Parse trait config + trait_config = TraitConfig( + pipeline=EnumConfigParam( + values=config_dict["trait_config"]["pipeline"]["values"] + ), + scheduler=EnumConfigParam( + values=config_dict["trait_config"]["scheduler"]["values"] + ), + epilogue=EnumConfigParam( + values=config_dict["trait_config"]["epilogue"]["values"] + ), + pad_m=EnumConfigParam( + values=config_dict["trait_config"]["pad_m"]["values"] + ), + pad_n=EnumConfigParam( + values=config_dict["trait_config"]["pad_n"]["values"] + ), + pad_k=EnumConfigParam( + values=config_dict["trait_config"]["pad_k"]["values"] + ), + ) + + return cls(tile_config=tile_config, trait_config=trait_config) + + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format: {str(e)}") + except KeyError as e: + raise KeyError(f"Missing required configuration field: {str(e)}") diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp new file mode 100644 index 0000000000..41fddf30aa --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_multi_d_dispatcher.hpp" +#include "gemm_multi_d_common.hpp" + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.") + .insert("stride_e", "0", "The stride value for tensor E Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "1", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU. Default is 1, validation on CPU, as validation on GPU is " + "not supported.") + .insert("log", + "false", + "Wether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert("warmup", + "50", + "The number of iterations before benchmarking the kernel. Default is 50.") + .insert("repeat", + "100", + "The number of iterations for benchmarking the kernel. Default is 100.") + .insert("timer", + "true", + "Indicates whether the timer is a GPU timer. Possible values are true or false. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "false", + "To flush cache, possible values are true or false. " + "Default is false.") + .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "gemm_multi_d_kernel", + "The filename of benchmark result. Default is set to gemm_multi_d_kernel.") + .insert( + "pipeline", + "compv3", + "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") + .insert("scheduler", + "intrawave", + "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " + "compv3.") + .insert( + "epilogue", + "cshuffle", + "The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.") + .insert("pad_m", + "false", + "Whether pad or not in m direction. Possible values are true or false. Default is " + "false.") + .insert("pad_n", + "false", + "Whether pad or not in n direction. Possible values are true or false. Default is " + "false.") + .insert("pad_k", + "false", + "Whether pad or not in k direction. Possible values are true or false. Default is " + "false."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) +{ + KernelTraits trait; + trait.pipeline = arg_parser.get_str("pipeline"); + trait.scheduler = arg_parser.get_str("scheduler"); + trait.epilogue = arg_parser.get_str("epilogue"); + trait.pad_m = arg_parser.get_bool("pad_m"); + trait.pad_n = arg_parser.get_bool("pad_n"); + trait.pad_k = arg_parser.get_bool("pad_k"); + + return GemmMultiDDispatcher::dispatch(trait); +} diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py new file mode 100755 index 0000000000..6e65f6bf75 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -0,0 +1,755 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# -*- coding: utf-8 -*- + +""" +generate kernel instances to speed up compilation +""" + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional +from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam +from gemm_multi_d_codegen_utils import ( + DATA_TYPE_MAP, + LAYOUT_MAP, + PIPELINE_MAP, + SCHEDULER_MAP, + EPILOGUE_MAP, + BOOL_MAP, + warp_tile_supported_combinations, + trait_unsupported_combinations, + element_size, + get_gpu_name_by_id, +) +import logging + +logging.basicConfig(level=logging.INFO) + + +class GemmMultiDCodeGenerator: + """GEMM (General Matrix Multiplication) Multi D code generator.""" + + def __init__( + self, + args: argparse.Namespace, + user_provided_config: Optional[JsonConfig] = None, + ): + self.output_dir = Path(args.working_path) + self.output_dir.mkdir(parents=True, exist_ok=True) + + if user_provided_config is not None: + self.config = user_provided_config + else: + config_path = ( + Path(__file__).resolve().parent / "configs" / "default_config.json" + ) + self.config = JsonConfig.from_json(config_path) + + self.args = ArgumentConfig.from_args( + args.datatype, args.layout, args.elementwise_function + ) + + self.valid_trait_names: List[str] = [] + self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {} + + def list_all_trait_names(self): + """List all possible kernel trait names into file.""" + w_p = Path(self.output_dir) + file_path = w_p / "gemm_multi_d_instance_blobs.txt" + self._generate_all_traits() + self._get_valid_trait_tile_combinations() + file_range_map = {} + # Write all file paths to the header file + files_listed = 0 + with file_path.open("w") as f: + # Core files + core_files = [ + "gemm_multi_d_common.hpp", + "gemm_multi_d_instances.hpp", + "gemm_multi_d_dispatcher.hpp", + ] + for core_file in core_files: + f.write(str(w_p / core_file) + "\n") + files_listed += 1 + + # Trait header files + for trait in self.valid_trait_names: + trait_file = f"gemm_multi_d_{trait}.hpp" + f.write(str(w_p / trait_file) + "\n") + files_listed += 1 + file_name = set() + # Instance source files + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + start_idx = files_listed + for tile in tile_valid_params: + for ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + _, + _, + _, + ) in tile: + instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" + + if instance_name not in file_name: + file_name.add(instance_name) + f.write(str(w_p / instance_name) + "\n") + files_listed += 1 + + file_range_map[trait] = (start_idx, files_listed) + + file_path = w_p / "gemm_multi_d_instance_blobs_range.txt" + with file_path.open("w") as f: + for name, ranges in file_range_map.items(): + start, last = ranges + f.write(name + " " + f"{start}" + " " + f"{last}" + "\n") + + def _generate_all_traits(self): + """Generate all possible kernel traits names.""" + params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"] + + # Generate all unique_combinations + _unique = set( + itertools.product( + *[getattr(self.config.trait_config, param).values for param in params] + ) + ) + + for combo in _unique: + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo + current_combination = (pipeline, epilogue, scheduler) + + if current_combination not in trait_unsupported_combinations: + trait_name = ( + f"{pipeline}_{epilogue}_{scheduler}_" + f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}" + ) + self.valid_trait_names.append(trait_name) + else: + logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}") + + def _get_valid_trait_tile_combinations(self): + def get_tile_value(tile_param): + return ( + tile_param.generate_candidates() + if isinstance(tile_param, RangeConfigParam) + else tile_param.values + ) + + tile_group = list( + itertools.product( + get_tile_value(self.config.tile_config.tile_m), + get_tile_value(self.config.tile_config.tile_n), + get_tile_value(self.config.tile_config.tile_k), + ) + ) + + warp_group = list( + itertools.product( + get_tile_value(self.config.tile_config.warp_m), + get_tile_value(self.config.tile_config.warp_n), + get_tile_value(self.config.tile_config.warp_k), + ) + ) + + warp_tile_group = list( + itertools.product( + get_tile_value(self.config.tile_config.warp_tile_m), + get_tile_value(self.config.tile_config.warp_tile_n), + get_tile_value(self.config.tile_config.warp_tile_k), + ) + ) + + tile_params = { + t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group + } + + for trait in self.valid_trait_names: + tile_valid_params = [ + tile for tile in tile_params if self.is_tile_valid(tile, trait) + ] + + if trait not in self.valid_trait_tile_combinations: + self.valid_trait_tile_combinations[trait] = [] + self.valid_trait_tile_combinations[trait].append(tile_valid_params) + + def is_tile_valid(self, tile: tuple, trait: str) -> bool: + """Check if the tile configuration is valid for the given trait.""" + ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) = tile + pipeline, *_ = trait.split("_") + + # Parameter validity check + invalid_params = [] + if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]: + invalid_params.append( + f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})" + ) + if (warp_m * warp_tile_m) == 0: + invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})") + if (warp_n * warp_tile_n) == 0: + invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})") + if (warp_k * warp_tile_k) == 0: + invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})") + + if invalid_params: + logging.debug( + f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. " + f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), " + f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})" + ) + return False + # Dimension alignment check + alignment_issues = [] + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + if alignment_issues: + logging.debug( + f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # LDS capacity verification + matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype) + + matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype) + + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + + if total_tile_in_lds > max_tile_size: + logging.debug( + f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" + f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False + + # Warp combination validation + warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}" + + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + gpu_name = get_gpu_name_by_id(0) + + gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) + if not gpu_warp_tile_key: + logging.debug( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." + ) + return False + + allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) + if not allowed_combinations: + logging.debug( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." + ) + return False + + if current_combination not in allowed_combinations: + logging.debug( + f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. " + f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}" + ) + return False + + return True + + def generate_all_instance_files(self): + """Generate all kernel instances files.""" + self._generate_common_header_file() + self._generate_all_trait_files() + self._generate_dispatcher_file() + + def _generate_common_header_file(self): + """Generate common header file with datatypes and layout.""" + + acc_type = "float" # As we are currently supporting only fp16 + + content = f""" +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +// Data types +using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]}; +using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]}; +using AccDataType = {acc_type}; +using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]}; +using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]}; +using DsDataType = ck_tile::tuple; +using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]}; + + +// Layout configurations +using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]}; +using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]}; +using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]}; +using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]}; +using DsLayout = ck_tile::tuple; +using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]}; + +// Element-wise function for D +using ElementWiseFn = ck_tile::element_wise::{self.args.function_name}; + +""" + + (self.output_dir / "gemm_multi_d_common.hpp").write_text(content) + + def _generate_all_trait_files(self): + """Generate all kernel traits into files.""" + if not self.valid_trait_names: + self._generate_all_traits() + self._get_valid_trait_tile_combinations() + for trait in self.valid_trait_names: + self._generate_trait_file(trait) + self._generate_instantiation_source_files() + self._generate_common_instance_header_file() + + def _generate_trait_file(self, trait: str): + """Generate a trait with all tile/warp combinations.""" + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_") + filename = f"gemm_multi_d_{trait}.hpp" + + content = f""" +#pragma once + +#include "gemm_multi_d_common.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/host.hpp" + +namespace {trait} {{ +""" + # Add template struct with configuration + content += self._generate_kernel_struct( + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k + ) + + content += f"\n}} // namespace {trait}\n" + (self.output_dir / filename).write_text(content) + + def _generate_kernel_struct( + self, + pipeline: str, + epilogue: str, + scheduler: str, + pad_m: str, + pad_n: str, + pad_k: str, + ) -> str: + """Generate the code block of kernel struct""" + return f""" + +template +struct GemmKernelMultiD {{ + static constexpr bool kPadM = {pad_m}; + static constexpr bool kPadN = {pad_n}; + static constexpr bool kPadK = {pad_k}; + + static float launch(ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {{ + static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; + + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; + {EPILOGUE_MAP[epilogue]} + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + if(stream.log_level_ > 0) + {{ + std::cout << "Launching kernel with args:" + << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + ave_time = ck_tile::launch_kernel(stream, + ck_tile::make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + + }}; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ + if(args.k_batch == 1) {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; + }} + + static std::string get_name() {{ + return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + + "_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" + + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" + + "{pad_m}" + "_" + + "{pad_n}" + "_" + + "{pad_k}" + "_" + + "{pipeline}" + "_" + + "{epilogue}" + "_" + + "{scheduler}"; + }} +}}; +""" + + def _generate_instantiation_source_files(self): + """Generate kernel instance instantiation source files""" + tile_map = {} + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + for tile in tile_valid_params: + for ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) in tile: + key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}" + value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + if key not in tile_map: + tile_map[key] = set() + tile_map[key].add(value) + + files_listed = 0 + for trait, _ in self.valid_trait_tile_combinations.items(): + for block_tile, warp_tiles in tile_map.items(): + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map( + int, block_tile.split("x") + ) + + content = f""" +#include "gemm_multi_d_{trait}.hpp" + +""" + for warp_tile in warp_tiles: + warp_tile_m, warp_tile_n, warp_tile_k = map( + int, warp_tile.split("x") + ) + + files_listed = files_listed + 1 + content = ( + content + + f""" +template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;""" + ) + content += """ +""" + ( + self.output_dir + / f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" + ).write_text(content) + print(f"Generated {files_listed} kernel instances in total.") + + def _generate_common_instance_header_file(self): + """Generate common instance header into file.""" + content = """ +#pragma once +""" + for trait in self.valid_trait_names: + content += f'#include "gemm_multi_d_{trait}.hpp"\n' + (self.output_dir / "gemm_multi_d_instances.hpp").write_text(content) + + def _generate_dispatcher_file(self): + """Generate the code block of dispatch mechanism.""" + content = """ +#pragma once + +#include +#include +#include + +#include "gemm_multi_d_common.hpp" +#include "gemm_multi_d_instances.hpp" + +/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a +/// specific kernel instance based on the provided settings. +struct KernelTraits +{ + /// @brief The name of the pipeline. + std::string pipeline; + /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). + std::string scheduler; + /// @brief The name of the epilogue (e.g., "cshuffle", "default"). + std::string epilogue; + /// @brief Indicates whether padding is applied to the M dimension. + bool pad_m; + /// @brief Indicates whether padding is applied to the N dimension. + bool pad_n; + /// @brief Indicates whether padding is applied to the K dimension. + bool pad_k; +}; + +struct GemmMultiDDispatcher { + static auto& get_kernel_map() { + // Use a static local variable + static std::unordered_map< + std::string, + std::vector(ck_tile::GemmMultiDHostArgs&, const ck_tile::stream_config&)>>> + kernel_map; + return kernel_map; + } + + static void init() { + auto& kernel_map = get_kernel_map(); + if(!kernel_map.empty()) return; + \n""" + + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + content += f""" kernel_map["{trait}"] = {{""" + for _, tile in enumerate(tile_valid_params): + for j in range(len(tile)): + ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) = tile[j] + content += """[=](ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) { """ + + content += f""" + return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);""" + + if j == len(tile) - 1: + content += """ + } """ + else: + content += """ + }, """ + content += """ + };\n """ + + content += """ } + + template + static std::tuple run_kernel(ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) + { + std::string name = Kernel::get_name(); + float avg_time = Kernel::launch(args, stream); + + return std::make_tuple(name, avg_time); + } + + + static auto dispatch(const KernelTraits& trait) { + init(); + const std::string key = assemble_key(trait); + auto& kernel_map = get_kernel_map(); + if(auto it = kernel_map.find(key); it != kernel_map.end()) + { + return it->second; + } + throw std::runtime_error("No suitable kernel found: " + key); + } + +private: + static std::string assemble_key(const KernelTraits &trait) { + return std::string(trait.pipeline) + "_" + + trait.epilogue + "_" + + trait.scheduler + "_" + + (trait.pad_m ? "true" : "false") + "_" + + (trait.pad_n ? "true" : "false") + "_" + + (trait.pad_k ? "true" : "false"); + } +}; + +""" + (self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content) + + +def do_list_blobs( + args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None +): + generator = GemmMultiDCodeGenerator(args, user_provide_config) + generator.list_all_trait_names() + + +def do_gen_blobs( + args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None +): + generator = GemmMultiDCodeGenerator(args, user_provide_config) + generator.generate_all_instance_files() + + +def main(args): + gemm_multi_d_config = JsonConfig.from_json(args.config_json) + + if args.list_blobs: + do_list_blobs(args, gemm_multi_d_config) + elif args.gen_blobs: + do_gen_blobs(args, gemm_multi_d_config) + else: + logging.warning( + "No mode specified (use --list_blobs or --gen_blobs). Generating by default..." + ) + do_gen_blobs(args, gemm_multi_d_config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK gemm multi D kernel", + ) + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="The path where all the blobs are going to be generated", + ) + parser.add_argument( + "-j", + "--config_json", + required=False, + help="Path to the json which contains the configurations that user provide", + ) + parser.add_argument( + "-d", + "--datatype", + required=True, + help="Specify what datatype to use for the kernel generation, e.g. fp16", + ) + parser.add_argument( + "-ly", + "--layout", + required=True, + help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr", + ) + parser.add_argument( + "-ef", + "--elementwise_function", + required=True, + help="Specify what element wise function for D, e.g. mul, add, passthrough", + ) + parser.add_argument( + "-l", + "--list_blobs", + action="store_true", + help="List all kernel instances to file", + ) + parser.add_argument( + "-g", + "--gen_blobs", + action="store_true", + help="Generate all kernel instances into different files", + ) + + args = parser.parse_args() + + main(args) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp new file mode 100644 index 0000000000..0106d76c05 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "benchmark_gemm_multi_d.hpp" + +class GemmMultiDProfiler +{ + public: + static GemmMultiDProfiler& instance(Setting setting) + { + static GemmMultiDProfiler instance{setting}; + return instance; + } + + void benchmark( + GemmMultiDProblem& gemm_multi_d_problem, + std::vector( + ck_tile::GemmMultiDHostArgs&, const ck_tile::stream_config&)>>& + callables) + { + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const D0Layout layout_d0 = D0Layout{}; + const D1Layout layout_d1 = D1Layout{}; + const ELayout layout_e = ELayout{}; + + gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_, + gemm_multi_d_problem.k_, + gemm_multi_d_problem.stride_a_, + is_row_major(layout_a)); + gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_b_, + is_row_major(layout_b)); + gemm_multi_d_problem.stride_d0_ = + ck_tile::get_default_stride(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_d0_, + is_row_major(layout_d0)); + gemm_multi_d_problem.stride_d1_ = + ck_tile::get_default_stride(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_d1_, + is_row_major(layout_d1)); + gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_e_, + is_row_major(layout_e)); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, + gemm_multi_d_problem.k_, + gemm_multi_d_problem.stride_a_, + is_row_major(layout_a))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_b_, + is_row_major(layout_b))); + ck_tile::HostTensor d0_m_n( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_d0_, + is_row_major(layout_d0))); + ck_tile::HostTensor d1_m_n( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_d1_, + is_row_major(layout_d1))); + ck_tile::HostTensor e_m_n_device_result( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_e_, + is_row_major(layout_e))); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.mData.data()); + b_k_n_dev_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array stridesDs = { + gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_}; + + ck_tile::GemmMultiDHostArgs gemm_multi_d_args = { + a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + gemm_multi_d_problem.split_k_, + gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.k_, + gemm_multi_d_problem.stride_a_, + gemm_multi_d_problem.stride_b_, + stridesDs, + gemm_multi_d_problem.stride_e_, + }; + + ck_tile::HostTensor e_m_n_host_result( + ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, + gemm_multi_d_problem.n_, + gemm_multi_d_problem.stride_e_, + is_row_major(layout_e))); + + if(setting_.verify_) + { + gemm_multi_d_host_reference( + setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result); + } + + for(auto& callable : callables) + { + auto kernel_run_result = + callable(gemm_multi_d_args, + ck_tile::stream_config{ + nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_}); + + auto [kernel_name, execution_time] = kernel_run_result; + + process_result(gemm_multi_d_problem, + e_m_n_dev_buf, + e_m_n_host_result, + e_m_n_device_result, + kernel_run_result); + } + } + + void process_result(const GemmMultiDProblem& gemm_multi_d_problem, + ck_tile::DeviceMem& e_m_n_dev_buf, + ck_tile::HostTensor& e_m_n_host_result, + ck_tile::HostTensor& e_m_n_dev_result, + const std::tuple& kernel_run_result) + { + auto [name, avg_time] = kernel_run_result; + + KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}}; + + static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); + std::size_t flop = 0, num_byte = 0; + flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ * + gemm_multi_d_problem.k_; + ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { + num_byte += sizeof(ck_tile::remove_cvref_t>) * + gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; + flop += sizeof(ck_tile::remove_cvref_t>) * + gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; + }); + num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ + + sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ + + sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; + + kernel_instance.perf_result_.latency_ = avg_time; + kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; + kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; + + if(setting_.log_ > 0) + { + std::cout << kernel_instance << std::endl; + } + + e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data()); + bool verified_correct = + !setting_.verify_ || + compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result); + + if(verified_correct) + { + kernel_instances_.emplace_back(kernel_instance); + } + else + { + std::cout << "Verification failed, skip kernel: " << name << std::endl; + } + + e_m_n_dev_buf.SetZero(); + e_m_n_dev_result.SetZero(); + } + + KernelInstance select_best_instance(Metric metric) + { + if(kernel_instances_.empty()) + throw std::runtime_error("Empty instances"); + + auto kernel_instance = *std::max_element(kernel_instances_.begin(), + kernel_instances_.end(), + [metric](const auto& a, const auto& b) { + return PerformanceResult::compare( + b.perf_result_, a.perf_result_, metric); + }); + + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "The best kernel instance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + + if(!setting_.csv_filename_.empty()) + { + std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app); + + if(!file.is_open()) + { + std::cerr << "Warning: Failed to open CSV file for writing." << std::endl; + } + else + { + if(file.tellp() == 0) + { + file << "rocm_version,device_name," + << "split_k,m,n,k,stride_a,stride_b,stride_c," + << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c," + << "structured_sparsity," << "name," + << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; + } + + const auto& problem = kernel_instance.problem_; + const auto& name = kernel_instance.name_; + const auto& perf = kernel_instance.perf_result_; + + file << get_rocm_version() << "," << ck_tile::get_device_name() << "," + << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," + << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," + << problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_ + << "," << problem.dtype_a_ << "," << problem.dtype_b_ << "," + << problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_ + << "," << problem.dtype_e_ << "," << problem.layout_a_ << "," + << problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_ + << "," << problem.layout_e_ << "," << "," << name << "," << std::fixed + << std::setprecision(4) << perf.latency_ << "," << std::fixed + << std::setprecision(4) << perf.tflops_ << "," << std::fixed + << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) + << "\n"; + + if(!file) + { + std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl; + } + } + } + + return kernel_instance; + } + + GemmMultiDProfiler(const GemmMultiDProfiler&) = delete; + GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete; + + private: + ~GemmMultiDProfiler() { kernel_instances_.clear(); } + GemmMultiDProfiler(Setting setting) : setting_(setting) {} + + Setting setting_; + + std::vector kernel_instances_; +}; From 0f42a92fc127f727e004d867eb2cc5177f626143 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 12 Aug 2025 18:23:34 -0700 Subject: [PATCH 9/9] Finish the grouped gemm restructure with fp8 data type (#2655) * Finish the grouped gemm restructure with data type * restore gemm_utils.hpp * Update example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Comment Addressed --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 105 ++++--------- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 148 +++++++++++++++--- .../run_grouped_gemm_example.inc | 115 +++++++++++--- 3 files changed, 251 insertions(+), 117 deletions(-) mode change 100644 => 100755 example/ck_tile/17_grouped_gemm/grouped_gemm.cpp diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp old mode 100644 new mode 100755 index 897952f03c..a821af0649 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,91 +16,50 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -template +template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, bool splitk) { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = true; -#endif - constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; // We create the GEMM pipeline without specifying hotloop or tailnumber. @@ -112,7 +71,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmUniversalTraits, scheduler>; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; @@ -145,7 +105,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -173,4 +133,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, #include "run_grouped_gemm_example.inc" constexpr bool Persistent = true; -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + return !run_grouped_gemm_example(argc, argv); +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 89d91fbef6..e992cb3118 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -15,24 +15,26 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4 +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; #else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; #endif +} template struct GemmTypeConfig; @@ -46,13 +48,109 @@ struct GemmTypeConfig using AccDataType = float; }; -using Types = GemmTypeConfig; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 1; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; @@ -69,6 +167,7 @@ auto create_args(int argc, char* argv[]) .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") @@ -98,7 +197,14 @@ float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr); -template +template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index fa7f1a31c1..425299203f 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -10,6 +10,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -30,7 +31,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template ( - stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr, splitk); } std::string op_name{"Grouped Gemm"}; @@ -127,7 +135,15 @@ float invoke_gemm(int n_warmup, return ave_time; } -template +template int run_grouped_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -243,7 +259,8 @@ int run_grouped_gemm_example_with_layouts(int argc, {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } - invoke_gemm, AccDataType, @@ -271,7 +288,9 @@ int run_grouped_gemm_example_with_layouts(int argc, a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref, "Error: Incorrect results!", @@ -288,7 +307,61 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template typename GemmConfig> int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -297,30 +370,22 @@ int run_grouped_gemm_example(int argc, char* argv[]) return -1; } - const std::string a_layout = arg_parser.get_str("a_layout"); - const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string data_type = arg_parser.get_str("prec"); - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if(a_layout == "R" && b_layout == "C") + if(data_type == "fp16") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + return run_gemm_example_prec_type, ck_tile::half_t>( + a_layout, b_layout, argc, argv); } - else if(a_layout == "R" && b_layout == "R") + else if(data_type == "fp8") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + return run_gemm_example_prec_type, ck_tile::fp8_t>( + a_layout, b_layout, argc, argv); } else { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + throw std::runtime_error("Unsupported data type configuration."); } }