From d40ea30e94cf76154ff411a968d1aef347aa40a9 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 12 Aug 2025 13:22:05 +0000 Subject: [PATCH] Merge commit '5b39de4bb61a3f0399fcd384f3a82c5e6ce28e5e' into develop --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 - .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 147 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 3 - .../ck_tile/01_fmha/script/benchmark_fwd.sh | 11 + .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 21 +- .../core/arch/amd_buffer_addressing.hpp | 17 +- .../arch/amd_buffer_addressing_builtins.hpp | 17 +- include/ck_tile/core/arch/arch.hpp | 27 +- include/ck_tile/core/arch/utility.hpp | 15 - include/ck_tile/core/config.hpp | 10 - include/ck_tile/core/numeric/bfloat16.hpp | 11 - include/ck_tile/core/numeric/pk_fp4.hpp | 2 +- include/ck_tile/core/numeric/pk_int4.hpp | 2 +- include/ck_tile/core/numeric/vector_type.hpp | 12 +- .../unary_element_wise_operation.hpp | 7 + include/ck_tile/ops/fmha.hpp | 2 - .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 1530 +++++------------ ...block_fmha_bwd_pipeline_default_policy.hpp | 24 +- .../pipeline/block_fmha_pipeline_enum.hpp | 7 - .../pipeline/block_fmha_pipeline_problem.hpp | 2 - ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1177 ------------- ..._pipeline_qr_ks_vs_async_trload_policy.hpp | 823 --------- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 30 +- .../block/block_gemm_areg_breg_creg_v1.hpp | 178 +- .../ops/gemm/block/block_gemm_problem.hpp | 9 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 - .../gemm/pipeline/gemm_pipeline_problem.hpp | 48 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 8 - .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 - .../ck_tile/ops/reduce/block/block_reduce.hpp | 30 +- 31 files changed, 639 insertions(+), 3545 deletions(-) delete mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp delete mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..6fca800c90 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,7 +115,6 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -124,7 +123,6 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ce35c6a2a7..269af4e6a7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -12,7 +12,6 @@ from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * -from codegen.utils import update_file DTYPE_BITS = { @@ -84,7 +83,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, - {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -99,7 
+97,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; #include @@ -163,19 +161,12 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; }} """ -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -186,8 +177,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -230,7 +221,6 @@ class FmhaFwdApiTrait: dpad : str dvpad : str skip : str - tr_load : str constraint : CppConstraint @property @@ -241,19 +231,13 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async', 'qr_async_trload']: + if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: - return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -264,9 +248,6 @@ class FmhaFwdApiTrait: elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
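
Note on the dcheck/dvcheck predicates above: the generator derives the vectorized-load width from the element bit width via `(32 * 4) / DTYPE_BITS[...]`, i.e. how many elements fit into one 128-bit load, and then emits guards such as `a.hdim_q % vec == 0`. A minimal standalone C++ sketch of that arithmetic (names here are illustrative, not part of the generator):

    // How many elements fit in one 128-bit vector load, mirroring the
    // `(32 * 4) / DTYPE_BITS[...]` expression used by dcheck/dvcheck.
    #include <cstdio>

    constexpr int elems_per_128bit_load(int dtype_bits) { return 128 / dtype_bits; }

    int main()
    {
        // fp16/bf16 -> 8 elements, fp8/bf8 -> 16 elements per load, so the
        // generated guard becomes e.g. `a.hdim_q % 8 == 0` for fp16.
        std::printf("fp16: %d, fp8: %d\n",
                    elems_per_128bit_load(16), elems_per_128bit_load(8));
    }
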
(ugly) else : return f'a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' else: assert False @property @@ -275,7 +256,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -287,7 +268,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -309,7 +290,6 @@ class FmhaFwdPipeline: F_squant : str # F_mask : str # value from MASK_MAP F_skip : str # true/false - F_trload : str # true/false F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -351,9 +331,6 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' return n @@ -374,39 +351,31 @@ class FmhaFwdApiPool: @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } - - per_tr_load =str() - for tr_load in ["t", "f"]: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) - if not per_tr_load: + per_dtypes=str() + for i, dtype in 
enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass class FmhaFwdTileSize: @@ -489,8 +458,7 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: @@ -526,7 +494,6 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) class KernelComponentFactory: @@ -536,15 +503,10 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 
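
The `api` property above stitches the generated C++ dispatcher out of the FMHA_FWD_API_PER_DTYPE / PER_HDIM_CASE / INNER_DISPATCH templates: runtime checks select one fully specialized trait set, which is then handed to a templated implementation. A self-contained C++ miniature of that pattern (all names below are hypothetical, not the generated symbols):

    // Miniature of the generated dispatch pattern: runtime checks pick a
    // compile-time trait pack, then call a templated implementation.
    #include <cstdio>

    template <int HDim, bool PadSeqQ>
    struct example_traits
    {
        static constexpr int  kHDim    = HDim;
        static constexpr bool kPadSeqQ = PadSeqQ;
    };

    template <typename Traits>
    float example_run(int seqlen_q)
    {
        std::printf("hdim=%d pad=%d seqlen_q=%d\n",
                    Traits::kHDim, Traits::kPadSeqQ ? 1 : 0, seqlen_q);
        return 0.f;
    }

    float example_dispatch(int hdim_q, int seqlen_q)
    {
        if(hdim_q <= 64)
        {
            if(seqlen_q % 128 == 0) { return example_run<example_traits<64, false>>(seqlen_q); }
            else                    { return example_run<example_traits<64, true>>(seqlen_q); }
        }
        else if(hdim_q <= 128)
        {
            return example_run<example_traits<128, true>>(seqlen_q);
        }
        return -1.f; // unsupported combination
    }

    int main() { return static_cast<int>(example_dispatch(64, 256)); }
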
32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -572,27 +534,34 @@ class KernelComponentFactory: if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 
't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -630,12 +599,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -702,10 +665,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + (autogen_dir / kernel.filename).write_text(kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d0f8e3798c..c0e4dc3d30 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1135,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush << std::endl; + << " GB/s" << std::flush; if(do_validation == 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp 
b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..81dda692ea 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -1029,7 +1028,6 @@ template struct fmha_fwd_traits_ { @@ -1054,7 +1052,6 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; - static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..599c595a75 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,14 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +for perm in 0 1 ; do + +$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 + +done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index dc2be933bd..b867cd6c07 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -42,6 +42,7 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do + for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -50,16 +51,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx 
-kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 
-d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07be65a150..35da19cd3e 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,6 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1314,17 +1318,6 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, - // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 - // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system - // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse - WAVE_NT0 = 0, - WAVE_NT1 = 2, - GROUP_NT0 = 1, - GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, }; template & src_thread_ #if defined(__gfx950__) template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index c64b296408..8c3bc0bc36 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -32,6 +32,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1182,17 +1186,6 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, - // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 - // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system - // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal 
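
The `make_wave_buffer_resource()` change in both buffer-addressing headers broadcasts each 32-bit component of the buffer resource descriptor through `__builtin_amdgcn_readfirstlane`, making the value wave-uniform so the compiler can keep the descriptor in scalar registers (buffer instructions take the resource from SGPRs). A device-side HIP sketch of the same pattern, with an illustrative vector type:

    // Sketch of the readfirstlane pattern: forcing each dword of the 128-bit
    // descriptor to be wave-uniform so it is materialized in SGPRs, not VGPRs.
    #include <hip/hip_runtime.h>

    using int32x4_t = int __attribute__((ext_vector_type(4)));

    __device__ int32x4_t make_uniform_descriptor(int32x4_t r)
    {
        r.x = __builtin_amdgcn_readfirstlane(r.x);
        r.y = __builtin_amdgcn_readfirstlane(r.y);
        r.z = __builtin_amdgcn_readfirstlane(r.z);
        r.w = __builtin_amdgcn_readfirstlane(r.w);
        return r;
    }
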
reuse - WAVE_NT0 = 0, - WAVE_NT1 = 2, - GROUP_NT0 = 1, - GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, }; template -__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index f0e9518120..ab42ec8617 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -89,6 +89,21 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } +CK_TILE_DEVICE void block_sync_lds() +{ +#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +#else + __syncthreads(); +#endif +} + CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) { #ifdef __gfx12__ @@ -159,18 +174,6 @@ CK_TILE_DEVICE void s_waitcnt_barrier() __builtin_amdgcn_s_barrier(); } -template -CK_TILE_DEVICE void block_sync_lds() -{ - s_waitcnt_barrier(); -} - -template -CK_TILE_DEVICE void block_sync_lds_direct_load() -{ - s_waitcnt_barrier(); -} - CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 93008f8525..7184f99521 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -59,21 +59,6 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } -template -CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) -{ - static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); - - const int32x2_t x = __builtin_amdgcn_permlane32_swap( - bit_cast(v_local), bit_cast(v_local), false, false); - - thread_buffer v; - v(0) = bit_cast(x[0]); - v(1) = bit_cast(x[1]); - - return v; -} - template CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) { diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index e472bd01e5..c471f416c3 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -191,16 +191,6 @@ #endif #endif -// use llvm builtin bf16 data type after ROCm 6.5 -#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16 -#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ - (HIP_VERSION_MAJOR >= 7) -#define CK_TILE_USE_LLVM_BUILTIN_BF16 1 -#else -#define CK_TILE_USE_LLVM_BUILTIN_BF16 0 -#endif -#endif - #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 245fb7244f..6f31468809 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,9 +6,6 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/numeric.hpp" -#if CK_TILE_USE_LLVM_BUILTIN_BF16 -#include -#endif #include #pragma once @@ -105,11 +102,7 @@ struct native_t using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else -#if CK_TILE_USE_LLVM_BUILTIN_BF16 -using bfloat16_t = __bf16; -#else using bfloat16_t = ushort; -#endif using bf16_t = bfloat16_t; using bf16_raw_t = uint16_t; #endif @@ -287,11 +280,7 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE 
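
The new `block_sync_lds()` above implements the commented-out `s_waitcnt lgkmcnt(0)` with `__builtin_amdgcn_s_waitcnt(0xc07f)`. A standalone decode of that immediate, assuming the gfx9-style s_waitcnt field layout (vmcnt[3:0] in bits 3:0, expcnt in bits 6:4, lgkmcnt in bits 11:8, vmcnt[5:4] in bits 15:14 — check the ISA manual for the exact target): lgkmcnt is driven to 0 (wait for all LDS/GDS/scalar-memory/message ops) while vmcnt and expcnt stay at their maxima, i.e. no wait on vector memory or exports.

    // Illustrative decode of the 0xc07f immediate under the gfx9-style layout.
    #include <cstdio>

    int main()
    {
        constexpr unsigned imm     = 0xc07f;
        constexpr unsigned vmcnt   = (imm & 0xfu) | ((imm >> 10) & 0x30u); // bits 3:0 and 15:14
        constexpr unsigned expcnt  = (imm >> 4) & 0x7u;
        constexpr unsigned lgkmcnt = (imm >> 8) & 0xfu;
        std::printf("vmcnt=%u expcnt=%u lgkmcnt=%u\n", vmcnt, expcnt, lgkmcnt); // 63 7 0
    }
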
constexpr bfloat16_t float_to_bf16(float f, constant = {}) { -#if defined(__gfx950__) - return static_cast(f); -#else return bit_cast(float_to_bf16_raw(f, constant{})); -#endif } template using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); -using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index bbd3d53827..58bdb43b08 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); // bf16 // using bf16_t = ... -using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); -using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4))); -using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8))); -using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16))); -using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32))); -using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); // i32 // using int32_t = ... diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index b69c167315..0e385901ed 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -330,6 +330,13 @@ struct PassThrough y = type_convert(x); } + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const + { + y = type_convert(x); + } + template <> CK_TILE_HOST_DEVICE void operator()(float& y, const ck_tile::fp16_t& x) const diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 69f645b850..d8dd5db12e 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -52,8 +52,6 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 5b3d38d3e7..8d257a3329 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -13,7 +13,6 @@ #include #include -#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, 
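
With `CK_TILE_USE_LLVM_BUILTIN_BF16` removed, `bfloat16_t` falls back to `ushort` and the `bf16xN` vectors are rebuilt from `bf16_raw_t`, i.e. bf16 travels as a raw 16-bit payload. A standalone sketch of one common fp32-to-bf16 raw conversion (truncation with round-to-nearest-even via the 0x7FFF + LSB trick); this is illustrative and not necessarily the exact scheme `float_to_bf16_raw` uses (NaN handling omitted):

    // fp32 -> raw 16-bit bf16 payload, matching the "bf16 stored as ushort"
    // representation used above. RNE rounding shown for illustration only.
    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    static uint16_t float_to_bf16_rne(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        const uint32_t lsb      = (u >> 16) & 1u; // last kept mantissa bit
        const uint32_t rounding = 0x7fffu + lsb;  // round to nearest, ties to even
        return static_cast<uint16_t>((u + rounding) >> 16);
    }

    int main()
    {
        std::printf("bf16(1.0f) = 0x%04x\n", float_to_bf16_rne(1.0f)); // 0x3f80
    }
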
seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -62,14 +61,6 @@ struct FmhaFwdKernel static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; -#if defined(__gfx950__) - static constexpr bool kIsAvialable = true; -#else - static constexpr bool kIsAvialable = !kUseTrLoad; -#endif - static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; @@ -109,7 +100,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on @@ -1045,1142 +1036,455 @@ struct FmhaFwdKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - if constexpr(kIsAvialable) - run_(std::move(kargs)); - } + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; - CK_TILE_DEVICE void run_(Kargs kargs) const - { - if constexpr(kPipelineName != "qr_async_trload") + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) { - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - - if constexpr(kIsGroupMode) + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - 
batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - - if constexpr(kSkipMinSeqlenQ) - { - if(kargs.seqlen_q <= kargs.min_seqlen_q) - { - return; - } - } - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - } + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) { - batch_offset_bias = - static_cast(i_batch) * kargs.batch_stride_bias; + return; } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } - // Q/K/V DRAM and DRAM window - const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - if 
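
In group (variable-length) mode the per-batch sequences are packed back-to-back and described by the `seqstart_q_ptr`/`seqstart_k_ptr` prefix-sum arrays: the offset of batch i is `seqstart[i] * stride` and its length is `seqstart[i + 1] - seqstart[i]`, and blocks whose `i_m0` falls beyond that length return early. A standalone sketch of that bookkeeping, with hypothetical names:

    // Group-mode bookkeeping: seqstart is a prefix sum over per-batch lengths,
    // so adjacent entries give both the offset and the length of batch i.
    #include <cstdio>

    struct BatchSlice
    {
        long offset; // element offset of the first row of this batch
        int  seqlen; // number of rows in this batch
    };

    BatchSlice get_batch_slice(const int* seqstart, int i_batch, long stride)
    {
        const long start = seqstart[i_batch];
        const int  len   = seqstart[i_batch + 1] - seqstart[i_batch];
        return {start * stride, len};
    }

    int main()
    {
        const int seqstart[] = {0, 100, 356, 400}; // 3 batches: 100, 256, 44 rows
        const BatchSlice s   = get_batch_slice(seqstart, 1, /*stride=*/128);
        std::printf("offset=%ld seqlen=%d\n", s.offset, s.seqlen); // 12800 256
    }
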
constexpr(FmhaPipeline::kQLoadOnce) - { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); - - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); - - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); - - auto v_dram_window = make_tile_window( - v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. 
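
The `pad_tensor_view(..., sequence<kPadSeqLenQ, kPadHeadDimQ>{})` calls above logically extend the Q/K/V views up to tile-size multiples along the flagged dimensions so that the tile windows are always full-size, with out-of-range elements handled by the view. A rough standalone sketch of the underlying idea (this is not the ck_tile API, just the concept):

    // Conceptual sketch: round a length up to the next tile multiple and guard
    // out-of-range rows on read, which is what padded views make transparent.
    #include <cstdio>

    constexpr int pad_to_multiple(int len, int tile) { return ((len + tile - 1) / tile) * tile; }

    float load_guarded(const float* p, int row, int col, int rows, int stride)
    {
        return row < rows ? p[static_cast<long>(row) * stride + col] : 0.f; // padded rows read as 0
    }

    int main()
    {
        const float row0[4] = {1.f, 2.f, 3.f, 4.f};
        std::printf("%d %f\n",
                    pad_to_multiple(/*seqlen_q=*/100, /*kM0=*/128),      // 128
                    load_guarded(row0, /*row=*/2, /*col=*/1, /*rows=*/1, /*stride=*/4)); // 0.0
    }
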
- /// Remove following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = - make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + - batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = - make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host - ? kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; - }; - }(); - - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? 
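
The dropout setup above reads the philox seed/offset either as an immediate host-side value or through a device pointer (`is_drop_seed_offset_from_host ? drop_seed.val : *drop_seed.ptr`). A minimal host-side sketch of that value-or-pointer kernel-argument pattern; the struct and names below are hypothetical:

    // Value-or-pointer argument pattern: the caller supplies either the scalar
    // itself or a pointer to device-resident memory holding it.
    #include <cstdint>
    #include <cstdio>

    struct SeedArg
    {
        union { uint64_t val; const uint64_t* ptr; };
        bool from_host;
    };

    uint64_t resolve(const SeedArg& s) { return s.from_host ? s.val : *s.ptr; }

    int main()
    {
        const uint64_t stored = 42; // stands in for device-resident memory
        SeedArg a{}; a.val = 7;       a.from_host = true;
        SeedArg b{}; b.ptr = &stored; b.from_host = false;
        std::printf("%llu %llu\n",
                    static_cast<unsigned long long>(resolve(a)),
                    static_cast<unsigned long long>(resolve(b))); // 7 42
    }
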
- SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); - - AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); - - BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - - auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - }(); - - // O DRAM and O DRAM window - auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto o_dram_window = make_tile_window( - o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_dram_window, o_acc_tile); + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } } else { - // TODO: Refine the logical here. - // In Decode case - // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache - // 2. limit the LDS usage, as we want higher occupancy - // In Prefill case - // 1. we expect KV data reused by different ThreadGroups, use cache - // 2. 
use more LDS, as we want better memory latency hiding - // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the - // cache - constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; - const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; // unused for paged-kvcache - long_index_t batch_offset_v = 0; // unused for paged-kvcache - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - // index_t kv_l2p_offset = - // 0; // logical-to-physical offset of seqlen_k coordinate. only used for - // paged-kvcache - - if constexpr(kIsGroupMode) + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) { - batch_offset_v = key_start * kargs.stride_v; + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); } else { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - - batch_offset_lse = query_start; - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - kargs.seqlen_k = - kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; } } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = - static_cast(i_batch) * kargs.batch_stride_bias; - } + return EmptyPositionEncoding{}; } + }(); - // for simplicity, batch stride we just modify the pointer - const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; + AttentionVariant 
variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_v + - batch_offset_v; + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + }(); - // Q/K/V DRAM and DRAM window - const auto q_dram = [&] { - const auto q_dram_naive = [&] { - { - return make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - } - }(); + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); - if constexpr(FmhaPipeline::kQLoadOnce) - { - const auto seqlen_q = kargs.seqlen_q; - const auto q_dram_pad = pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_unmerge_transform( - make_tuple(seqlen_q / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_merged = transform_tensor_view( - q_dram_unmerged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto q_dram_unmerged_xor = transform_tensor_view( - q_dram_merged, - make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - 
make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged_xor, - make_tuple( - make_xor_transform( - make_tuple(seqlen_q / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto q_dram_tmp = transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q / XorLengthFold), - make_unmerge_transform( - make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - q_dram_tmp, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(seqlen_q / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_pass_through_transform(seqlen_q), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged, - make_tuple( - make_xor_transform(make_tuple(seqlen_q, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - q_dram_permuted, - make_tuple( - make_pass_through_transform(seqlen_q), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - - const auto make_k_dram = [&](const KDataType* data, index_t height) { - const auto k_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(height, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - const auto k_dram_pad = pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple(make_unmerge_transform( - make_tuple(height / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_merged = transform_tensor_view( - k_dram_unmerged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto k_dram_unmerged_xor = transform_tensor_view( - k_dram_merged, - make_tuple(make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - 
make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged_xor, - make_tuple( - make_xor_transform( - make_tuple(height / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto k_dram_tmp = transform_tensor_view( - k_dram_permuted, - make_tuple( - make_pass_through_transform(height / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - k_dram_tmp, - make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(height / XorLengthFold, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto k_dram_unmerged = transform_tensor_view( - k_dram_pad, - make_tuple( - make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto k_dram_permuted = transform_tensor_view( - k_dram_unmerged, - make_tuple( - make_xor_transform(make_tuple( - height, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - k_dram_permuted, - make_tuple( - make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - const auto k_dram = [&]() { - { - return make_k_dram(k_ptr, kargs.seqlen_k); - } - }(); - - const auto make_v_dram = [&](const VDataType* data, index_t length) { - const auto v_dram_naive = make_naive_tensor_view( - data, // will update this pointer if using paged-kvcache - make_tuple(length, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), - number{}, - number<1>{}); - - // TODO: Add kVHeadDim - constexpr index_t XorGroupSize = - FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); - - const auto v_dram_pad = pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); - constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); - - if constexpr(XorLengthFold > 1) - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple(make_unmerge_transform( - make_tuple(length / XorLengthFold, XorLengthFold)), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_merged = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_merge_transform_v3_division_mod(make_tuple( - XorLengthFold, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto v_dram_unmerged_xor = transform_tensor_view( - v_dram_merged, - make_tuple( - make_pass_through_transform(length / XorLengthFold), - 
make_unmerge_transform(make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged_xor, - make_tuple( - make_xor_transform(make_tuple(length / XorLengthFold, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - const auto v_dram_tmp = transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length / XorLengthFold), - make_unmerge_transform(make_tuple( - number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_view( - v_dram_tmp, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(length / XorLengthFold, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - const auto v_dram_unmerged = transform_tensor_view( - v_dram_pad, - make_tuple(make_pass_through_transform(length), - make_unmerge_transform( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto v_dram_permuted = transform_tensor_view( - v_dram_unmerged, - make_tuple(make_xor_transform(make_tuple( - length, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_view( - v_dram_permuted, - make_tuple(make_pass_through_transform(length), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - }; - - const auto v_dram = [&]() { - { - return make_v_dram(v_ptr, kargs.seqlen_k); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); - else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); - - auto v_dram_window = make_tile_window( - v_dram, - make_tuple(number{}, number{}), - {0, 0}); - - /// FIXME: Before C++20, capturing structured binding variables are not supported. 
- /// Remove following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = - make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse acc - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + - batch_offset_lse; - - const auto lse_dram = [&] { - const auto lse_dram_naive = [&] { - { - return make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - } - }(); - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? 
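
The ALIBI branch here reads one slope per (batch, head) and, when CK_TILE_FMHA_FWD_FAST_EXP2 is defined, pre-multiplies it by log2(e) so the additive bias stays consistent with the exp2-based softmax. As a point of reference only, the sketch below shows the usual form of an ALiBi bias term; it is a simplified, symmetric illustration and not the ck_tile Alibi implementation, which also handles top-left / bottom-right alignment via the mask type.

// Illustrative ALiBi bias: distant key positions receive a linear penalty
// proportional to the per-head slope. Simplified illustration only.
#include <cstdlib>

float alibi_bias(float slope, int q_pos, int k_pos)
{
    return -slope * static_cast<float>(std::abs(q_pos - k_pos));
}
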
- SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask( - slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); - - auto o_acc_tile = [&]() { - if constexpr(PrefillCase) - { - // allocate double lds - // add __restrict__ here to avoid aliasing - __shared__ char smem_ptrk0 - [FmhaPipeline::Policy::template GetSmemSizeK()]; - __shared__ char smem_ptrk1 - [FmhaPipeline::Policy::template GetSmemSizeK()]; - __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV< - typename FmhaPipeline::Problem>()]; - __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV< - typename FmhaPipeline::Problem>()]; - - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - smem_ptrk0, - smem_ptrk1, - smem_ptrv0, - smem_ptrv1); - } - else - { - __shared__ char smem_ptr[GetSmemSize()]; - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - smem_ptr); - } - }(); - - // Oacc DRAM and Oacc DRAM window - auto o_dram = [&] { - const auto o_dram_naive = [&] { - { - return make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); - } - }(); - - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto o_dram_window = make_tile_window( - o_dram, + return pad_tensor_view( + o_dram_naive, make_tuple(number{}, number{}), - {i_m0, i_n1}); + sequence{}); + }(); - EpiloguePipeline{}(o_dram_window, o_acc_tile); - } + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index f6a20c5cb5..aa2ec99590 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = 
detail::make_embed_tile_distribution_encoding( @@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}; constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), pt_warp_tensor.get_thread_buffer()); }); @@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), dst_warp_tensor.get_thread_buffer()); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 45a1c8f4b8..cf70dff63f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -11,7 +11,6 @@ enum class BlockFmhaPipelineEnum QRKSVS = 0, QRKSVS_ASYNC, QSKSVS, - QRKSVS_ASYNC_TRLOAD, }; template @@ -33,10 +32,4 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qs"; }; -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qr_async_trload"; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 86ac713b6f..20b30b7417 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -22,7 +22,6 @@ template struct BlockFmhaPipelineProblem { @@ -47,7 +46,6 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff 
--git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp deleted file mode 100644 index 39d8814692..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ /dev/null @@ -1,1177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" -#include "ck_tile/ops/reduce/block/block_reduce.hpp" - -namespace ck_tile { - -// This pipeline is qkv all located in LDS -template -struct BlockFmhaPipelineQRKSVSAsyncTrload -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using AttentionVariant = remove_cvref_t; - using FmhaMask = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static_assert(kQLoadOnce == Policy::QLoadOnce); - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1); - static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1); - - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - - // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && - // Problem::kPadHeadDimV == true); - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = - Problem::kPadHeadDimQ; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = - Problem::kPadHeadDimV; // support multiple of vector(like 8x) - - static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; - static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasUnevenSplits = true; - - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && - (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || - !kHasLogitsSoftCap)) || - (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); - - // last dimension vector length used to create tensor view(and decide buffer_load 
vector length) - // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); - - static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO(); - - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - - static constexpr index_t kBlockPerCu = []() { - if constexpr(Problem::kBlockPerCu != -1) - return Problem::kBlockPerCu; - else - { - if constexpr(kQKHeaddim <= 32) - { - return 2; - } - else if constexpr(kQKHeaddim <= 64) - { - return 3; - } - else if constexpr(kQKHeaddim <= 128) - { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256) - return 1; - else - return 2; - } - else if constexpr(kQKHeaddim <= 256) - { - return 1; - } - else - { - return 1; - } - } - }(); - - static constexpr const char* name = "qr_async_trload"; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - // Decode - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* smem_ptr) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && - kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], - "wrong!"); - ignore = bias_dram_block_window_tmp; - ignore = position_encoding; - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - auto o_acc = OaccBlockTileType{}; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(o_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - // init M, L - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); - - const auto q_origin = q_dram_block_window_tmp.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(I0), 
number{}, number{}); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) - { - const index_t logical_num_total_loop = - integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); - if(logical_num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse_acc = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse_acc, -numeric::infinity()); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. - return o_acc; - } - } - - // Q tile in LDS - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); - - auto q_lds_write_view = make_tensor_view( - static_cast(smem_ptr), Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_read_view = make_tensor_view( - static_cast(smem_ptr), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_store_window = - make_tile_window(q_lds_write_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto q_lds_read_window = - make_tile_window(q_lds_read_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - - async_load_tile(q_lds_store_window, q_dram_window); - - // K tile in LDS - const index_t physical_seqlen_k_start = logical_seqlen_k_start; - const index_t physical_seqlen_k_end = logical_seqlen_k_end; - // make sure the first tile is completely located in page-block (page-block size should be - // divisible by kN0) - // relationship between each *_start variables: aligned_physical_seqlen_k_start <= - // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start - const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); - - auto k_lds_write_view = make_tensor_view( - static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_read_view = make_tensor_view( - static_cast(smem_ptr), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_write_window = - make_tile_window(k_lds_write_view, - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0}); - auto k_lds_read_window = - make_tile_window(k_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution()); - - // S tile in LDS - auto s_lds = make_tensor_view( - reinterpret_cast(reinterpret_cast(smem_ptr) + - Policy::template GetSmemSizeK()), - Policy::template MakeSLdsBlockDescriptor()); - auto s_write_lds_window = make_tile_window( - s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); - auto s_read_lds_window = - make_tile_window(s_lds, - Policy::template MakeSLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeSRegTileDistribution()); - - // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); - - auto v_lds_write_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeS()), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_read_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetSmemSizeK() + - Policy::template GetSmemSizeS()), - Policy::template 
MakeVLdsBlockDescriptor()); - auto v_lds_write_window = - make_tile_window(v_lds_write_view, - Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto v_lds_read_window = - make_tile_window(v_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution()); - - block_sync_lds_direct_load<0>(); - auto q_tile = load_tile(q_lds_read_window); - - const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); - - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - - constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); - constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); - - do - { - block_sync_lds(); - async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile - - // move V tile windows - move_tile_window(v_dram_window, {kN0, 0}); - - // STAGE 1, QK gemm - clear_tile(s_acc); // initialize C - - if constexpr(1 < k0_loops) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - if constexpr(i_k0 == 0) - { - block_sync_lds_direct_load(); - } - else - { - block_sync_lds_direct_load<0>(); - } - - auto k_tile = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_tile); - - // loop over along the [K]ey head dimension - move_tile_window(k_dram_window, {0, kK0}); - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - }); - // move back to the origin - move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)}); - } - - if constexpr(k0_loops == 1) - { - block_sync_lds_direct_load(); - } - else - { - block_sync_lds_direct_load<0>(); - } - - auto k_tile = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_tile); - - if constexpr(kHasUnevenSplits) - { - if(i_total_loops == (num_total_loop - 1)) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - set_tile_if(s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(I0) + tile_idx.at(I1); - - { - return physical_seqlen_k_end_ <= col; - } - }); - } - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - - bool need_perpixel_check = - mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(I0) + tile_idx.at(I0); - const auto col = k_origin.at(I0) + tile_idx.at(I1); - return mask.IsOutOfBound(row, col); - }); - } - } - - // move K tile windows after current status checked - // prefetch next-tile along [K]ey sequence length dimension - move_tile_window(k_dram_window, {kN0, 0}); - - block_sync_lds(); - async_load_tile(k_lds_write_window, k_dram_window); - - // Gemm1 - auto s_new = [&]() { - if constexpr(kNWarp > 1) - { - auto s = cast_tile(s_acc); // S{j} - - store_tile(s_write_lds_window, s); - block_sync_lds(); - return load_tile(s_read_lds_window); - } - else - { - return cast_tile(s_acc); // S{j} - } - }(); - - auto m_local = block_tile_reduce( - s_new, - sequence<1>{}, - 
f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - // Set CrossWarp to false will trigger better strategy on gfx950, but will cause - // performance regression because of un-coexecutable packed math, silent it for now - block_tile_reduce_sync( - m_local, f_max, bool_constant{} /*, bool_constant{}*/); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s_new.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - auto row_max = scale_s * get_validated_m(m[i_idx]); - sweep_tile_span(p_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); - } - } - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync( - rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); - - auto p_tile = make_static_distributed_tensor( - Policy::template MakePRegTileDistribution()); - p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); - - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - } - }(); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds_direct_load(); - - auto v_tile = load_tile_transpose(v_lds_read_window); - - if constexpr(1 < k1_loops) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}), - v_tile); - - // loop over along the [V]alue Sequence length - move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); - }); - // move back to the origin - move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); - } - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), - v_tile); - - } while(++i_total_loops < num_total_loop); - 
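
For reference, the main loop above (part of the pipeline this patch removes) implements the standard online-softmax recurrence of flash attention: keep a running row maximum m and row sum l, rescale the partial output whenever m grows, and normalize by l after the last key/value tile. The scalar sketch below restates that recurrence for a single query row; the names and std::vector containers are illustrative only and do not correspond to ck_tile APIs, and it uses exp rather than the exp2 / log2e formulation used when CK_TILE_FMHA_FWD_FAST_EXP2 is enabled.

// Minimal scalar sketch of the online-softmax update performed per key tile.
// Illustrative only: the real pipeline operates on distributed register tiles.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

void online_softmax_step(const std::vector<float>& s,              // scaled scores q*k^T for this tile
                         const std::vector<std::vector<float>>& v, // value rows for this tile
                         std::vector<float>& o,                    // running, un-normalized output row
                         float& m,                                 // running row max (init -inf)
                         float& l)                                 // running row sum of exp (init 0)
{
    float m_local = -std::numeric_limits<float>::infinity();
    for(float x : s)
        m_local = std::max(m_local, x);

    const float m_new = std::max(m, m_local);
    if(m_new == -std::numeric_limits<float>::infinity())
        return; // fully masked so far, nothing to accumulate

    const float scale = std::exp(m - m_new); // rescale previously accumulated output and sum
    for(float& acc : o)
        acc *= scale;
    l *= scale;

    for(std::size_t j = 0; j < s.size(); ++j)
    {
        const float p = std::exp(s[j] - m_new);
        l += p;
        for(std::size_t d = 0; d < o.size(); ++d)
            o[d] += p * v[j][d];
    }
    m = m_new;
}

// After the last tile: o[d] /= l (guarding l == 0 for fully masked rows) and,
// if requested, LSE = m + log(l), matching the kStoreLSE epilogue below.
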
- if constexpr(kStoreLSE) - { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } - } - }); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return o_acc; - } - - // Prefill, double lds - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - void* __restrict__ smem_ptrk0, - void* __restrict__ smem_ptrk1, - void* __restrict__ smem_ptrv0, - void* __restrict__ smem_ptrv1) const - { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && - kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], - "wrong!"); - ignore = bias_dram_block_window_tmp; - ignore = position_encoding; - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - auto o_acc = OaccBlockTileType{}; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(o_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - // init M, L - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); - - const 
auto q_origin = q_dram_block_window_tmp.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) - { - const index_t logical_num_total_loop = - integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); - if(logical_num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse_acc = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse_acc, -numeric::infinity()); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. - return o_acc; - } - } - - // Q tile in LDS - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); - - auto q_lds_write_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_read_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeQLdsBlockDescriptor()); - - auto q_lds_store_window = - make_tile_window(q_lds_write_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto q_lds_read_window = - make_tile_window(q_lds_read_view, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - - async_load_tile(q_lds_store_window, q_dram_window); - block_sync_lds_direct_load<0>(); - auto q_tile = load_tile(q_lds_read_window); - - // K tile in LDS - const index_t physical_seqlen_k_start = logical_seqlen_k_start; - const index_t physical_seqlen_k_end = logical_seqlen_k_end; - // make sure the first tile is completely located in page-block (page-block size should be - // divisible by kN0) - // relationship between each *_start variables: aligned_physical_seqlen_k_start <= - // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start - const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); - - auto k_lds_write_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_read_view = make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()); - - auto k_lds_write_window = - make_tile_window(k_lds_write_view, - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto k_lds_read_window = - make_tile_window(k_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution()); - - // S tile in LDS - auto s_lds = make_tensor_view( - reinterpret_cast(reinterpret_cast(smem_ptrk0) + - Policy::template GetSmemSizeK()), - Policy::template MakeSLdsBlockDescriptor()); - auto s_write_lds_window = make_tile_window( - s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); - auto s_read_lds_window = - make_tile_window(s_lds, - Policy::template MakeSLdsBlockDescriptor().get_lengths(), - {0, 0}, - Policy::template MakeSRegTileDistribution()); - - // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); - - auto v_lds_write_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template 
MakeVLdsBlockDescriptor()); - - auto v_lds_read_view = make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template MakeVLdsBlockDescriptor()); - - auto v_lds_write_window = - make_tile_window(v_lds_write_view, - Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0}); - - auto v_lds_read_window = - make_tile_window(v_lds_read_view, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution()); - - // block_sync_lds_direct_load<0>(); - // auto q_tile = load_tile(q_lds_read_window); - - const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); - - index_t i_total_loops = 0; - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - static_assert(1 <= k0_loops); - static_assert(1 <= k1_loops); - block_sync_lds<0>(); - async_load_tile(k_lds_write_window, k_dram_window); - async_load_tile(v_lds_write_window, v_dram_window); - - move_tile_window(k_dram_window, {kN0, 0}); - k_lds_write_window.set_bottom_tensor_view_data_ptr( - static_cast(smem_ptrk1)); - async_load_tile(k_lds_write_window, k_dram_window); - - constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); - constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); - - constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access(); - constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access(); - - block_sync_lds_direct_load(); - auto k_tile = load_tile(k_lds_read_window); - - __builtin_amdgcn_sched_barrier(0); - - auto mainloop = [&](index_t cur_loop) { - const bool is_even_loop = (cur_loop % 2 == 0); - - auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) - : static_cast(smem_ptrk1); - auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) - : static_cast(smem_ptrk0); - auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) - : static_cast(smem_ptrv0); - auto v_lds_read_ptr = is_even_loop ? 
static_cast(smem_ptrv0) - : static_cast(smem_ptrv1); - - // move V tile windows - block_sync_lds(); - move_tile_window(v_dram_window, {kN0, 0}); - v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr); - async_load_tile(v_lds_write_window, v_dram_window); - - // STAGE 1, QK gemm - clear_tile(s_acc); // initialize C - - if constexpr(1 < k0_loops) - { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - // loop over along the [K]ey head dimension - move_tile_window(k_lds_read_window, {0, kK0}); - auto k_tile_switch = load_tile(k_lds_read_window); - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_tile); - - k_tile = k_tile_switch; - }); - // move back to the origin - move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); - } - - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_tile); - - block_sync_lds_direct_load(); - v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr); - auto v_tile = load_tile_transpose(v_lds_read_window); - - if constexpr(kHasUnevenSplits) - { - if(i_total_loops == (num_total_loop - 1)) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - set_tile_if(s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(I0) + tile_idx.at(I1); - - { - return physical_seqlen_k_end_ <= col; - } - }); - } - } - - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); - - bool need_perpixel_check = - mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(I0) + tile_idx.at(I0); - const auto col = k_origin.at(I0) + tile_idx.at(I1); - return mask.IsOutOfBound(row, col); - }); - } - } - - // Gemm1 - auto s_new = [&]() { - if constexpr(kNWarp > 1) - { - auto s = cast_tile(s_acc); // S{j} - - store_tile(s_write_lds_window, s); - block_sync_lds(); - return load_tile(s_read_lds_window); - } - else - { - return cast_tile(s_acc); // S{j} - } - }(); - - auto m_local = block_tile_reduce( - s_new, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync( - m_local, f_max, bool_constant{} /*, bool_constant{}*/); - - static_for<0, 12, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ - }); - - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ - }); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s_new.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? 
type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - auto row_max = scale_s * get_validated_m(m[i_idx]); - sweep_tile_span(p_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); - } - } - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync( - rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); - - auto p_tile = make_static_distributed_tensor( - Policy::template MakePRegTileDistribution()); - p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); - - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - } - }(); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds(); - move_tile_window(k_dram_window, {kN0, 0}); - k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr); - async_load_tile(k_lds_write_window, k_dram_window); - - if constexpr(1 < k1_loops) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - // loop over along the [V]alue Sequence length - move_tile_window(v_lds_read_window, {kK1, 0}); - auto v_tile_switch = load_tile_transpose(v_lds_read_window); - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}), - v_tile); - - v_tile = v_tile_switch; - }); - // move back to the origin - move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); - } - - gemm_1(o_acc, - get_slice_tile(p_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), - v_tile); - - block_sync_lds_direct_load(); - k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr); - k_tile = load_tile(k_lds_read_window); - - static_for<0, 12, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ - }); - - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ - }); - }; - - do - { - mainloop(i_total_loops); - i_total_loops++; - } while(i_total_loops < num_total_loop); - - if constexpr(kStoreLSE) - { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto 
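
The loop body above is the standard streaming-softmax update: take the new row maximum, rescale the previous partial sum l and the output accumulator by the exponential of the max shift, then accumulate the current tile's probabilities. A scalar sketch of the recurrence for one row, using exp instead of the kernel's exp2 / log2(e) folding and omitting the all-masked guard that get_validated_m provides (all names are illustrative, and the initial state is assumed to be m = -inf, l = 0, acc = 0):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Streaming (online) softmax for one row, processed tile by tile.
// m is the running row max, l the running sum of exponentials, acc the
// running (unnormalized) weighted sum of values.
void online_softmax_step(const std::vector<float>& s_tile, // one row of S = Q*K^T for this tile
                         const std::vector<float>& v_tile, // matching values (scalar V for brevity)
                         float scale,
                         float& m, float& l, float& acc)
{
    float m_local = -std::numeric_limits<float>::infinity();
    for(float s : s_tile)
        m_local = std::max(m_local, s);

    const float m_old = m;
    m = std::max(m_old, m_local);

    // rescale the previous partial results to the new maximum
    const float correction = std::exp(scale * (m_old - m));
    l *= correction;
    acc *= correction;

    for(std::size_t i = 0; i < s_tile.size(); ++i)
    {
        const float p = std::exp(scale * (s_tile[i] - m));
        l += p;
        acc += p * v_tile[i];
    }
}

After the last tile the row output is acc / l, and the log-sum-exp kept for the backward pass is m (converted back to the natural-log domain) plus log(l), which is exactly what the epilogue in the code that follows computes.
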
lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } - } - }); - - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } - } - - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[I0], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[I1], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return o_acc; - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp deleted file mode 100644 index ed22758566..0000000000 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp +++ /dev/null @@ -1,823 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" - -// can remove all bank conflicts, but drop the performance for some cases -// Probably it is limited by compiler optimization. 
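
The bank-conflict remark above refers to the XOR swizzle the policy applies when writing K and V tiles to LDS: an xor transform over (row, column-pack) so that threads reading the same logical pack from consecutive rows hit different banks. A simplified index-level model of that idea, under the assumption that the xor acts on the pack index modulo the number of packs per row (the real layout is expressed with make_xor_transform inside a tensor descriptor, and the length-fold variant guarded by CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is not modelled here):

#include <cstddef>

// Map a (row, col) element of a row-major LDS tile to a swizzled offset:
// packs of k_pack contiguous elements keep their interior order, but the
// pack index is XOR-ed with the row index so consecutive rows place the
// same logical pack into different LDS banks.
constexpr std::size_t swizzled_offset(std::size_t row,
                                      std::size_t col,
                                      std::size_t cols_per_row,
                                      std::size_t k_pack)
{
    const std::size_t packs_per_row = cols_per_row / k_pack;
    const std::size_t pack          = col / k_pack;
    const std::size_t within_pack   = col % k_pack;
    const std::size_t swizzled_pack = pack ^ (row % packs_per_row);
    return row * cols_per_row + swizzled_pack * k_pack + within_pack;
}
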
-#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 -namespace ck_tile { -// This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy -{ - using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); - - // this should align with MakeQDramTileDistribution() - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() - { - using OaccDataType = remove_cvref_t; - - return static_cast(16 / sizeof(OaccDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - return min(ElemPerThread, MaxVectorSize); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() - { - if constexpr(!BypassLDS) - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); - - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - else - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = 
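
GetAlignmentQ/K/V/Oacc above all follow the same rule: the number of tile elements each thread owns, capped so that one load never exceeds 16 bytes. A stand-alone sketch of that selection (parameter names are illustrative stand-ins for the Problem/BlockFmhaShape fields):

#include <algorithm>
#include <cstddef>

// Pick the per-thread vector width: elements per thread for this tile,
// capped at a 16-byte (dwordx4) load.
template <typename T>
constexpr std::size_t pick_alignment(std::size_t rows_per_block,
                                     std::size_t cols_per_block,
                                     std::size_t block_size)
{
    const std::size_t max_vector_size  = 16 / sizeof(T); // bytes -> elements
    const std::size_t elems_per_thread = rows_per_block * cols_per_block / block_size;
    return std::min(elems_per_thread, max_vector_size);
}

// e.g. for a 2-byte dtype with a 128x64 tile on 256 threads:
//   elems_per_thread = 32, max_vector_size = 8  ->  8-element (16-byte) loads
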
Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() - { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = - LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; - - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - - constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read M first, then K - // This is the same data consume order as BlockGEMM - constexpr auto q_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() - { - // TODO: this is for 3d layout - using QDataType = remove_cvref_t; - return static_cast(16 / sizeof(QDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t kKPack = GetSmemKPackQ(); - - constexpr auto q_lds_block_desc = [&]() { - if constexpr(Xor) - { -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType); - constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto 
q_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - q_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple( - number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return q_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = - LoadOnce ? 
Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; - - constexpr index_t kKPack = GetSmemKPackK(); - - constexpr auto k_lds_block_desc = [&]() { - if constexpr(Xor) - { -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType); - constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( - k_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor( - k_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - k_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple( - number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( - k_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - k_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return k_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t kKPack = GetSmemKPackV(); - - constexpr auto v_lds_block_desc = [&]() { - if constexpr(Xor) - { - constexpr auto XorGroupSize = - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); - -#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType); - constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock; - - if constexpr(XorLengthFold > 1) - { - constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc_permuted = 
transform_tensor_descriptor( - v_lds_block_desc_naive, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor( - v_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - return transform_tensor_descriptor( - v_lds_block_desc_tmp, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - else -#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - { - constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( - v_lds_block_desc_naive, - make_tuple(make_xor_transform(make_tuple( - number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - return transform_tensor_descriptor( - v_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple( - number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - } - } - else - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - }(); - - return v_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() - { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - GemmLoopOrder::MNK>; - - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true>; - - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockGemmARegBRegCRegV1{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm() - { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - GemmLoopOrder::KMN>; - - using WarpGemm = WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || - (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) - ? 
WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single>; - - using BlockGemmPolicy = - BlockGemmARegBRegCRegV1CustomPolicy; - - return BlockGemmARegBRegCRegV1{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read N first, then K - // This is the same data consume order as BlockGEMM - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - - template - CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); - - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t NPerThread = kMaxVecLoad; - constexpr index_t NThreads = kNPerBlock / NPerThread; - constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read M first, then K - // This is the same data consume order as BlockGEMM - constexpr auto p_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - constexpr auto p_block_dstr = 
make_static_tile_distribution(p_block_dstr_encode); - - return p_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - // Read N first, then K - // This is the same data consume order as BlockGEMM - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = - make_static_tile_distribution(typename InputTileDistributionTraits< - decltype(v_block_dstr_encode), - typename Problem::VDataType>::TransposedDstrEncode{}); - - return v_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() - { - using SDataType = remove_cvref_t; - return static_cast(16 / sizeof(SDataType)); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kNPack = GetSmemNPackS(); - - constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto s_lds_block_desc = transform_tensor_descriptor( - s_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return s_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - // static_assert(MWarp == 1, "Check failed!"); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; - - // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm - constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = kKPerBlock / (K2 * K3); - constexpr index_t K0 = kTileK / kKPerBlock; - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - constexpr auto s2_block_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<2, 2>>, - sequence<1, 2, 2, 2>, - 
sequence<0, 0, 1, 3>>{}; - - constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); - - return s2_block_dstr; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - return MakeQLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::QDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::KDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() - { - return MakeVLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::VDataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS() - { - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - return NWarp > 1 ? MakeSLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::SaccDataType) - : 0; - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() - { - // Alignment on gfx950 is 1280 Bytes - // Alignment before gfx950 is 512 Bytes. - return max(GetSmemSizeQ(), - GetSmemSizeK() + GetSmemSizeS() + GetSmemSizeV()); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index e2cea97f9a..3489d6f9a1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -383,31 +383,23 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); - - return kMaxVecLoad; + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; - using VDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); - + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; if constexpr(std::is_same_v) { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) @@ -418,7 +410,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; - static constexpr auto BlockGemmLoopOrder = 
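
GetSmemSize in the deleted policy takes the maximum of the Q footprint and the K + S + V working set rather than their sum, which implies the Q staging space is not live at the same time as the mainloop buffers and can be aliased. A one-line sketch of that budget (byte counts are illustrative, not the real descriptor sizes):

#include <algorithm>
#include <cstddef>

// Q is staged through LDS only before the main loop, so its space can
// alias the K/S/V working set instead of adding to it.
constexpr std::size_t fmha_smem_bytes(std::size_t q_bytes,
                                      std::size_t k_bytes,
                                      std::size_t s_bytes,
                                      std::size_t v_bytes)
{
    return std::max(q_bytes, k_bytes + s_bytes + v_bytes);
}
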
Traits::BlockGemmLoopOrder; + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -89,36 +86,17 @@ struct BlockGemmARegBRegCRegV1 } else { - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; - } + return a_block_dstr_encode; } } @@ -140,33 +118,17 @@ struct BlockGemmARegBRegCRegV1 } else { - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; - } + return b_block_dstr_encode; } } @@ -251,82 +213,40 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) - { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - 
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - } - else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) - { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std:: + conditional_t, sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); }); }); - } + }); } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index d0be065fc9..fd5211a59a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ 
b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -14,8 +13,7 @@ template + index_t NumWaveGroups_ = 1> struct BlockGemmProblem { using ADataType = remove_cvref_t; @@ -23,9 +21,8 @@ struct BlockGemmProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t NumWaveGroups = NumWaveGroups_; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index b3c86b9456..b18bf603a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -39,12 +39,6 @@ enum struct TailNumber Full, }; -enum struct GemmLoopOrder -{ - KMN, - MNK, -}; - } // namespace ck_tile inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c628614b54..52bd07c9e2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,11 +14,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -46,10 +45,9 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; // In the base situation, the Preshuffle setting should be false. 
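
With GemmLoopOrder removed, the block GEMM keeps a single iteration order: K in the outer loop, then the per-warp M and N tiles, accumulating into the C warp tensors. A scalar model of that nest, with the warp-level MFMA replaced by a plain FMA (the loop bounds are warp-iteration counts, not matrix sizes, and nothing here is a ck_tile API):

#include <cstddef>

// Scalar model of the hot loop kept by this patch: for each K step,
// every (m, n) warp tile accumulates its product into C.
template <std::size_t M, std::size_t N, std::size_t K>
void block_gemm_kmn(const float (&a)[M][K], const float (&b)[N][K], float (&c)[M][N])
{
    for(std::size_t k = 0; k < K; ++k)           // KIterPerWarp
        for(std::size_t m = 0; m < M; ++m)       // MIterPerWarp
            for(std::size_t n = 0; n < N; ++n)   // NIterPerWarp
                c[m][n] += a[m][k] * b[n][k];    // stands in for WarpGemm{}(c, a, b)
}
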
static constexpr bool Preshuffle = false; @@ -169,11 +167,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> using GemmPipelineProblem = GemmPipelineProblemBase; + VectorSizeB_>; template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem static constexpr auto Scheduler = Scheduler_; static constexpr bool Preshuffle = Traits::Preshuffle; - static constexpr index_t VectorSizeA = VectorSizeA_; - static constexpr index_t VectorSizeB = VectorSizeB_; - static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index d1deaf9e0e..fb191d565d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -104,10 +104,6 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = 1>>; #endif -using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; - #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; #endif -using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = - WarpGemmImpl>>; - #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -76,8 +74,6 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 7a10d1fa56..434be9f84a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -14,14 +14,10 @@ namespace ck_tile { * Y dim must have at least one dim not been reduced */ // synchronize reduce result (cross lane reduction and 
broadcast on replicated dimension) -template +template CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, const ReduceFunc& reduce_func, - bool_constant = {}, - bool_constant = {}) + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -60,24 +56,14 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, // reduction sweep forward static_for<0, nstage, 1>{}([&](auto istage) { - if constexpr(CrossWarp) - { - constexpr index_t lid_delta = - lid_over_rid_derivative * (1 << (nstage - istage - 1)); + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); - // pull data from remote lane - const auto v_remote = warp_shuffle_down(v_local, lid_delta); + // pull data from remote lane + const auto v_remote = warp_shuffle_down(v_local, lid_delta); - // reduce - v_local = reduce_func(v_local, v_remote); - } - else - { - // pull data from remote lane - const auto v_swapped_regs = warp_shuffle_down_pair(v_local); - // reduce - v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1)); - } + // reduce + v_local = reduce_func(v_local, v_remote); }); } });
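
The simplified block_tile_reduce_sync above always combines values across lanes with warp_shuffle_down, halving the shuffle distance each stage. A host-side model of that tree reduction over one warp's values (the lane count is assumed to be a power of two; reduce_func plays the role of the kernel's f_max or f_sum):

#include <cstddef>
#include <functional>
#include <vector>

// At each stage every lane in the lower half combines its value with the lane
// `stride` positions away; after log2(n) stages lane 0 holds the full result.
float tree_reduce(std::vector<float> lanes,
                  const std::function<float(float, float)>& reduce_func)
{
    for(std::size_t stride = lanes.size() / 2; stride > 0; stride /= 2)
        for(std::size_t lane = 0; lane < stride; ++lane)
            lanes[lane] = reduce_func(lanes[lane], lanes[lane + stride]);
    return lanes[0];
}
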